File size: 31,443 Bytes
66407c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0ddc582b",
   "metadata": {},
   "source": [
    "# VeRL Ray API Tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71fe3b94",
   "metadata": {},
   "source": [
    "## Chapter 1: Ray Basics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "id": "1347d381",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "e75b9d44",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "import ray\n",
    "import torch\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "id": "2e90ae00",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-01 17:27:19,132\tINFO worker.py:1752 -- Started a local Ray instance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9cc9d2ccbdfb48918c8fd6cd13a0807a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<div class=\"lm-Widget p-Widget lm-Panel p-Panel jp-Cell-outputWrapper\">\n",
       "    <div style=\"margin-left: 50px;display: flex;flex-direction: row;align-items: center\">\n",
       "        <div class=\"jp-RenderedHTMLCommon\" style=\"display: flex; flex-direction: row;\">\n",
       "  <svg viewBox=\"0 0 567 224\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\" style=\"height: 3em;\">\n",
       "    <g clip-path=\"url(#clip0_4338_178347)\">\n",
       "        <path d=\"M341.29 165.561H355.29L330.13 129.051C345.63 123.991 354.21 112.051 354.21 94.2307C354.21 71.3707 338.72 58.1807 311.88 58.1807H271V165.561H283.27V131.661H311.8C314.25 131.661 316.71 131.501 319.01 131.351L341.25 165.561H341.29ZM283.29 119.851V70.0007H311.82C331.3 70.0007 342.34 78.2907 342.34 94.5507C342.34 111.271 331.34 119.861 311.82 119.861L283.29 119.851ZM451.4 138.411L463.4 165.561H476.74L428.74 58.1807H416L367.83 165.561H380.83L392.83 138.411H451.4ZM446.19 126.601H398L422 72.1407L446.24 126.601H446.19ZM526.11 128.741L566.91 58.1807H554.35L519.99 114.181L485.17 58.1807H472.44L514.01 129.181V165.541H526.13V128.741H526.11Z\" fill=\"var(--jp-ui-font-color0)\"/>\n",
       "        <path d=\"M82.35 104.44C84.0187 97.8827 87.8248 92.0678 93.1671 87.9146C98.5094 83.7614 105.083 81.5067 111.85 81.5067C118.617 81.5067 125.191 83.7614 130.533 87.9146C135.875 92.0678 139.681 97.8827 141.35 104.44H163.75C164.476 101.562 165.622 98.8057 167.15 96.2605L127.45 56.5605C121.071 60.3522 113.526 61.6823 106.235 60.3005C98.9443 58.9187 92.4094 54.9203 87.8602 49.0574C83.3109 43.1946 81.0609 35.8714 81.5332 28.4656C82.0056 21.0599 85.1679 14.0819 90.4252 8.8446C95.6824 3.60726 102.672 0.471508 110.08 0.0272655C117.487 -0.416977 124.802 1.86091 130.647 6.4324C136.493 11.0039 140.467 17.5539 141.821 24.8501C143.175 32.1463 141.816 39.6859 138 46.0505L177.69 85.7505C182.31 82.9877 187.58 81.4995 192.962 81.4375C198.345 81.3755 203.648 82.742 208.33 85.3976C213.012 88.0532 216.907 91.9029 219.616 96.5544C222.326 101.206 223.753 106.492 223.753 111.875C223.753 117.258 222.326 122.545 219.616 127.197C216.907 131.848 213.012 135.698 208.33 138.353C203.648 141.009 198.345 142.375 192.962 142.313C187.58 142.251 182.31 140.763 177.69 138L138 177.7C141.808 184.071 143.155 191.614 141.79 198.91C140.424 206.205 136.44 212.75 130.585 217.313C124.731 221.875 117.412 224.141 110.004 223.683C102.596 223.226 95.6103 220.077 90.3621 214.828C85.1139 209.58 81.9647 202.595 81.5072 195.187C81.0497 187.779 83.3154 180.459 87.878 174.605C92.4405 168.751 98.9853 164.766 106.281 163.401C113.576 162.035 121.119 163.383 127.49 167.19L167.19 127.49C165.664 124.941 164.518 122.182 163.79 119.3H141.39C139.721 125.858 135.915 131.673 130.573 135.826C125.231 139.98 118.657 142.234 111.89 142.234C105.123 142.234 98.5494 139.98 93.2071 135.826C87.8648 131.673 84.0587 125.858 82.39 119.3H60C58.1878 126.495 53.8086 132.78 47.6863 136.971C41.5641 141.163 34.1211 142.972 26.7579 142.059C19.3947 141.146 12.6191 137.574 7.70605 132.014C2.79302 126.454 0.0813599 119.29 0.0813599 111.87C0.0813599 104.451 2.79302 97.2871 7.70605 91.7272C12.6191 86.1673 19.3947 82.5947 26.7579 
81.6817C34.1211 80.7686 41.5641 82.5781 47.6863 86.7696C53.8086 90.9611 58.1878 97.2456 60 104.44H82.35ZM100.86 204.32C103.407 206.868 106.759 208.453 110.345 208.806C113.93 209.159 117.527 208.258 120.522 206.256C123.517 204.254 125.725 201.276 126.771 197.828C127.816 194.38 127.633 190.677 126.253 187.349C124.874 184.021 122.383 181.274 119.205 179.577C116.027 177.88 112.359 177.337 108.826 178.042C105.293 178.746 102.113 180.654 99.8291 183.44C97.5451 186.226 96.2979 189.718 96.3 193.32C96.2985 195.364 96.7006 197.388 97.4831 199.275C98.2656 201.163 99.4132 202.877 100.86 204.32ZM204.32 122.88C206.868 120.333 208.453 116.981 208.806 113.396C209.159 109.811 208.258 106.214 206.256 103.219C204.254 100.223 201.275 98.0151 197.827 96.97C194.38 95.9249 190.676 96.1077 187.348 97.4873C184.02 98.8669 181.274 101.358 179.577 104.536C177.879 107.714 177.337 111.382 178.041 114.915C178.746 118.448 180.653 121.627 183.439 123.911C186.226 126.195 189.717 127.443 193.32 127.44C195.364 127.443 197.388 127.042 199.275 126.259C201.163 125.476 202.878 124.328 204.32 122.88ZM122.88 19.4205C120.333 16.8729 116.981 15.2876 113.395 14.9347C109.81 14.5817 106.213 15.483 103.218 17.4849C100.223 19.4868 98.0146 22.4654 96.9696 25.9131C95.9245 29.3608 96.1073 33.0642 97.4869 36.3922C98.8665 39.7202 101.358 42.4668 104.535 44.1639C107.713 45.861 111.381 46.4036 114.914 45.6992C118.447 44.9949 121.627 43.0871 123.911 40.301C126.195 37.515 127.442 34.0231 127.44 30.4205C127.44 28.3772 127.038 26.3539 126.255 24.4664C125.473 22.5788 124.326 20.8642 122.88 19.4205ZM19.42 100.86C16.8725 103.408 15.2872 106.76 14.9342 110.345C14.5813 113.93 15.4826 117.527 17.4844 120.522C19.4863 123.518 22.4649 125.726 25.9127 126.771C29.3604 127.816 33.0638 127.633 36.3918 126.254C39.7198 124.874 42.4664 122.383 44.1635 119.205C45.8606 116.027 46.4032 112.359 45.6988 108.826C44.9944 105.293 43.0866 102.114 40.3006 99.8296C37.5145 97.5455 34.0227 96.2983 30.42 96.3005C26.2938 96.3018 22.337 97.9421 19.42 
100.86ZM100.86 100.86C98.3125 103.408 96.7272 106.76 96.3742 110.345C96.0213 113.93 96.9226 117.527 98.9244 120.522C100.926 123.518 103.905 125.726 107.353 126.771C110.8 127.816 114.504 127.633 117.832 126.254C121.16 124.874 123.906 122.383 125.604 119.205C127.301 116.027 127.843 112.359 127.139 108.826C126.434 105.293 124.527 102.114 121.741 99.8296C118.955 97.5455 115.463 96.2983 111.86 96.3005C109.817 96.299 107.793 96.701 105.905 97.4835C104.018 98.2661 102.303 99.4136 100.86 100.86Z\" fill=\"#00AEEF\"/>\n",
       "    </g>\n",
       "    <defs>\n",
       "        <clipPath id=\"clip0_4338_178347\">\n",
       "            <rect width=\"566.93\" height=\"223.75\" fill=\"white\"/>\n",
       "        </clipPath>\n",
       "    </defs>\n",
       "  </svg>\n",
       "</div>\n",
       "\n",
       "        <table class=\"jp-RenderedHTMLCommon\" style=\"border-collapse: collapse;color: var(--jp-ui-font-color1);font-size: var(--jp-ui-font-size1);\">\n",
       "    <tr>\n",
       "        <td style=\"text-align: left\"><b>Python version:</b></td>\n",
       "        <td style=\"text-align: left\"><b>3.9.2</b></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "        <td style=\"text-align: left\"><b>Ray version:</b></td>\n",
       "        <td style=\"text-align: left\"><b>2.10.0</b></td>\n",
       "    </tr>\n",
       "    \n",
       "</table>\n",
       "\n",
       "    </div>\n",
       "</div>\n"
      ],
      "text/plain": [
       "RayContext(dashboard_url='', python_version='3.9.2', ray_version='2.10.0', ray_commit='09abba26b5bf2707639bb637c208d062a47b46f6')"
      ]
     },
     "execution_count": 146,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[36m(GPUAccumulator pid=224400)\u001b[0m rank 0, value: tensor([1.], device='cuda:0')\n",
      "\u001b[36m(GPUAccumulator pid=225234)\u001b[0m rank 2, value: tensor([3.], device='cuda:0')\n",
      "\u001b[36m(GPUAccumulator pid=225607)\u001b[0m rank 0, value: tensor([2.], device='cuda:0')\n",
      "\u001b[36m(GPUAccumulator pid=226423)\u001b[0m rank 1, value: tensor([3.], device='cuda:0')\n",
      "\u001b[36m(GPUAccumulator pid=226857)\u001b[0m rank 3, value: tensor([6.], device='cuda:0')\n",
      "\u001b[36m(GPUAccumulatorDecorator pid=227475)\u001b[0m 10\n",
      "\u001b[36m(GPUAccumulatorDecorator pid=227475)\u001b[0m rank 0, value: tensor([10.], device='cuda:0')\n",
      "\u001b[36m(GPUAccumulatorDecorator pid=227655)\u001b[0m rank 1, value: tensor([11.], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Build a local ray cluster. The head node and worker node are on this machine\n",
    "ray.init()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a127e4e4",
   "metadata": {},
   "source": [
    "Implement an Accumulator class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "20e7b9a3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "class Accumulator:\n",
    "    def __init__(self):\n",
    "        self.value = 0\n",
    "\n",
    "    def add(self, x):\n",
    "        self.value += x\n",
    "\n",
    "    def get_value(self):\n",
    "        return self.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "3b80098c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Instantiate an accumulator. Accumulator can be viewed as a process, acting as an RPC service.\n",
    "accumulator = Accumulator.remote()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "id": "b14b1009",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "value_ref = accumulator.get_value.remote()  # Check the current value. Note that this function returns immediately and does not actually wait for the remote execution to complete.\n",
    "# Get the value\n",
    "value = ray.get(value_ref)\n",
    "print(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "513a84b3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n"
     ]
    }
   ],
   "source": [
    "# Accumulate, then check the result.\n",
    "accumulator.add.remote(10)  # Similarly, the 'add' here will return immediately.\n",
    "new_value = ray.get(accumulator.get_value.remote())\n",
    "print(new_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c332fe0",
   "metadata": {},
   "source": [
    "## Chapter 2: Resource Pool and RayWorkerGroup\n",
    "The previous example used a simple single-process worker. \n",
    "In this chapter, we implement a worker that uses a GPU and form a RayWorkerGroup from it. Within this RayWorkerGroup, we implement a simple accumulator operation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "04229afb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from verl.single_controller.base import Worker\n",
    "from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "id": "0d0dbd58",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "resource_pool = RayResourcePool([4], use_gpu=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "id": "68f6838a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "class GPUAccumulator(Worker):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        # The initial value of each rank is the same as the rank\n",
    "        self.value = torch.zeros(size=(1,), device=\"cuda\") + self.rank\n",
    "\n",
    "    def add(self, x):\n",
    "        self.value += x\n",
    "        print(f\"rank {self.rank}, value: {self.value}\")\n",
    "        return self.value.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "id": "23aad8fe",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([1.]), tensor([2.]), tensor([3.]), tensor([4.])]\n"
     ]
    }
   ],
   "source": [
    "# Each worker's initial value is its rank, and then each rank's value is incremented by 1, so the values obtained on each rank are [1, 2, 3, 4]\n",
    "class_with_args = RayClassWithInitArgs(cls=GPUAccumulator)\n",
    "worker_group = RayWorkerGroup(resource_pool, class_with_args)\n",
    "print(worker_group.execute_all_sync(\"add\", x=[1, 1, 1, 1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6705284",
   "metadata": {},
   "source": [
    "How parameters are passed: each input argument is a list of length `world_size`, and each element of the list is dispatched to the corresponding worker in the RayWorkerGroup. \n",
    "The return value is also a list, holding the return value of each worker."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d25c2412",
   "metadata": {},
   "source": [
    "### GPU Resource Sharing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f74f6d24",
   "metadata": {},
   "source": [
    "RayWorkerGroups mapped to the same resource pool share its GPUs. In this example, we use three resource pools: the first occupies 4 GPUs, the second also occupies 4 GPUs, and the last occupies all 8 GPUs. The first one is the resource pool created above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "49f9c06f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create a new resource pool and then merge the newly created resource pool with the previous one.\n",
    "resource_pool_1 = RayResourcePool([4], use_gpu=True, name_prefix=\"a\")\n",
    "resource_pool_merge = merge_resource_pool(resource_pool, resource_pool_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "id": "05c2e305",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Establish a RayWorkerGroup on the newly created resource pool.\n",
    "worker_group_1 = RayWorkerGroup(resource_pool_1, class_with_args)\n",
    "worker_group_merge = RayWorkerGroup(resource_pool_merge, class_with_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "id": "6b9b13f4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([2.]), tensor([3.]), tensor([4.]), tensor([5.])]\n"
     ]
    }
   ],
   "source": [
    "# Run 'add' on the second set of 4 GPUs; the result should be [2, 3, 4, 5].\n",
    "output_1 = worker_group_1.execute_all_sync(\"add\", x=[2, 2, 2, 2])\n",
    "print(output_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "id": "d856d030",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([3.]), tensor([4.]), tensor([5.]), tensor([6.]), tensor([7.]), tensor([8.]), tensor([9.]), tensor([10.])]\n"
     ]
    }
   ],
   "source": [
    "# Run 'add' on the merged set of 8 GPUs; the result should be [3, 4, 5, 6, 7, 8, 9, 10].\n",
    "output_merge = worker_group_merge.execute_all_sync(\"add\", x=[3, 3, 3, 3, 3, 3, 3, 3])\n",
    "print(output_merge)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "33a4628c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4 4 8\n"
     ]
    }
   ],
   "source": [
    "print(worker_group.world_size, worker_group_1.world_size, worker_group_merge.world_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3df19d13",
   "metadata": {},
   "source": [
    "## Chapter 3: Data Dispatch, Execution and Collection"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acb22d9d",
   "metadata": {},
   "source": [
    "In the above example, we used the `execute_all_sync` function of RayWorkerGroup to dispatch data from the driver to each worker, which is inconvenient to code against. \n",
    "In this chapter, we use function decorators so that a RayWorkerGroup can directly call the functions written in the Worker, which greatly simplifies parameter passing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "id": "35237432",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from verl.single_controller.base.decorator import Dispatch, Execute, register"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "id": "88b8ba3b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "class GPUAccumulatorDecorator(Worker):\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        # The initial value of each rank is the same as the rank\n",
    "        self.value = torch.zeros(size=(1,), device=\"cuda\") + self.rank\n",
    "\n",
    "    # map a single input to all the workers\n",
    "    @register(Dispatch.ONE_TO_ALL)\n",
    "    def add(self, x):\n",
    "        print(x)\n",
    "        self.value = self.value + x\n",
    "        print(f\"rank {self.rank}, value: {self.value}\")\n",
    "        return self.value.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "id": "eddaa043",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class_with_args = RayClassWithInitArgs(cls=GPUAccumulatorDecorator)\n",
    "gpu_accumulator_decorator = RayWorkerGroup(resource_pool_merge, class_with_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "id": "10087c91",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[tensor([10.]), tensor([11.]), tensor([12.]), tensor([13.]), tensor([14.]), tensor([15.]), tensor([16.]), tensor([17.])]\n"
     ]
    }
   ],
   "source": [
    "# As we can see, 10 is automatically dispatched to each Worker in this RayWorkerGroup.\n",
    "print(gpu_accumulator_decorator.add(x=10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "540ee6ad",
   "metadata": {},
   "source": [
    "### Custom Dispatch, Collection\n",
    "Users can customize the `dispatch` and `collect` functions: you only need to write the `dispatch_fn` and `collect_fn` functions yourself. We also support executing an RPC only on rank zero. Specific examples of both are provided below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "id": "8e041270",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from verl.single_controller.base.decorator import Dispatch, collect_all_to_all, register"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "43b5be31",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def two_to_all_dispatch_fn(worker_group, *args, **kwargs):\n",
    "    \"\"\"\n",
    "    Assume each input is a list of length 2. Extend it in place by repeating its two elements alternately until it has world_size entries, so that one element is passed to each worker.\n",
    "    \"\"\"\n",
    "    for arg in args:\n",
    "        assert len(arg) == 2\n",
    "        for i in range(worker_group.world_size - 2):\n",
    "            arg.append(arg[i % 2])\n",
    "    for k, v in kwargs.items():\n",
    "        assert len(v) == 2\n",
    "        for i in range(worker_group.world_size - 2):\n",
    "            v.append(v[i % 2])\n",
    "    return args, kwargs\n",
    "\n",
    "\n",
    "@ray.remote\n",
    "class TestActor(Worker):\n",
    "    # TODO: pass *args and **kwargs is bug prone and not very convincing\n",
    "    def __init__(self, x) -> None:\n",
    "        super().__init__()\n",
    "        self._x = x\n",
    "\n",
    "    def foo(self, y):\n",
    "        return self._x + y\n",
    "\n",
    "    @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)\n",
    "    def foo_rank_zero(self, x, y):\n",
    "        return self._x + y + x\n",
    "\n",
    "    @register(dispatch_mode={\"dispatch_fn\": two_to_all_dispatch_fn, \"collect_fn\": collect_all_to_all})\n",
    "    def foo_custom(self, x, y):\n",
    "        return self._x + y + x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "id": "83ec6609",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)\n",
    "worker_group = RayWorkerGroup(resource_pool, class_with_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "62c58d8a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])\n",
    "assert output_ref == [8, 10, 8, 10]\n",
    "\n",
    "output_ref = worker_group.foo_rank_zero(x=1, y=2)\n",
    "assert output_ref == 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "id": "14689353",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8\n"
     ]
    }
   ],
   "source": [
    "print(gpu_accumulator_decorator.world_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "id": "2c80bbf4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Shutdown ray cluster\n",
    "ray.shutdown()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5c8151c",
   "metadata": {},
   "source": [
    "## Chapter 4: NVMegatronRayWorkerGroup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd5680e9",
   "metadata": {},
   "source": [
    "Due to a Ray issue, we can only support `max_colocate_count=1` in RayResourcePool for now. \n",
    "This means that each GPU can only have one process.\n",
    "We can support `max_colocate_count > 1` once this pull request is applied: https://github.com/ray-project/ray/pull/44385"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92724419",
   "metadata": {},
   "source": [
    "Therefore, we need to restart Ray and initialize a new resource pool to demonstrate the **NVMegatronRayWorkerGroup**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b038538",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build a local ray cluster. The head node and worker node are on this machine\n",
    "ray.init()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebfd8798",
   "metadata": {},
   "source": [
    "Finally, we implement an `NVMegatronRayWorkerGroup`, within which we initialize Megatron model parallelism and then run a tensor-parallel (tp) split Llama MLP layer. Here, we use a more complex dispatch mode, `MEGATRON_COMPUTE`. This dispatch mode assumes that the user passes data partitioned along the DP dimension. The data is dispatched to all tp/pp ranks within the same dp group, and output data is ultimately collected only from tp=0 and the last pp rank. In this way, for users who only write code on the driver, the Megatron behind the RPC becomes transparent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "id": "5a032154",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/opt/tiger/Megatron-LM\n",
      "/opt/tiger/Megatron-LM/megatron/__init__.py\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "current_pythonpath = os.environ.get(\"PYTHONPATH\", \"\")\n",
    "\n",
    "new_path = \"/opt/tiger/Megatron-LM\"\n",
    "\n",
    "new_pythonpath = f\"{new_path}:{current_pythonpath}\" if current_pythonpath else new_path\n",
    "\n",
    "os.environ[\"PYTHONPATH\"] = new_pythonpath\n",
    "\n",
    "print(new_path)\n",
    "sys.path.append(new_path)\n",
    "\n",
    "import megatron\n",
    "\n",
    "print(megatron.__file__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "id": "8c84cd5a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from megatron.core import parallel_state as mpu\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from verl.single_controller.base.decorator import Dispatch, Execute, register\n",
    "from verl.single_controller.base.megatron.worker import MegatronWorker\n",
    "from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n",
    "from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "id": "1b1debcc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "resource_pool = RayResourcePool([4], use_gpu=True, max_colocate_count=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "id": "bccbe081",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "class MLPLayerWorker(MegatronWorker):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        rank = int(os.environ[\"LOCAL_RANK\"])\n",
    "        torch.distributed.init_process_group(backend=\"nccl\")\n",
    "        torch.cuda.set_device(rank)\n",
    "\n",
    "        mpu.initialize_model_parallel(\n",
    "            tensor_model_parallel_size=4,\n",
    "            pipeline_model_parallel_size=1,\n",
    "            virtual_pipeline_model_parallel_size=None,\n",
    "            pipeline_model_parallel_split_rank=None,\n",
    "            use_sharp=False,\n",
    "            context_parallel_size=1,\n",
    "            expert_model_parallel_size=1,\n",
    "            nccl_communicator_config_path=None,\n",
    "        )\n",
    "        from megatron.core import tensor_parallel\n",
    "\n",
    "        tensor_parallel.model_parallel_cuda_manual_seed(10)\n",
    "\n",
    "    @register(Dispatch.ONE_TO_ALL)\n",
    "    def init_model(self, config):\n",
    "        from omegaconf import OmegaConf\n",
    "\n",
    "        from verl.models.llama.megatron.layers import ParallelLlamaMLP\n",
    "        from verl.utils.megatron_utils import init_model_parallel_config\n",
    "\n",
    "        megatron_config = OmegaConf.create(\n",
    "            {\n",
    "                \"sequence_parallel\": False,\n",
    "                \"param_dtype\": \"fp32\",\n",
    "                \"tensor_model_parallel_size\": mpu.get_tensor_model_parallel_world_size(),\n",
    "                \"pipeline_model_parallel_rank\": mpu.get_pipeline_model_parallel_rank(),\n",
    "                \"pipeline_model_parallel_size\": mpu.get_pipeline_model_parallel_world_size(),\n",
    "                \"virtual_pipeline_model_parallel_rank\": mpu.get_virtual_pipeline_model_parallel_rank(),\n",
    "                \"virtual_pipeline_model_parallel_size\": mpu.get_virtual_pipeline_model_parallel_world_size(),\n",
    "            }\n",
    "        )\n",
    "\n",
    "        megatron_config = init_model_parallel_config(megatron_config)\n",
    "        self.parallel_layer = ParallelLlamaMLP(config=config, megatron_config=megatron_config)\n",
    "\n",
    "    @register(Dispatch.ONE_TO_ALL)\n",
    "    def get_weights(self):\n",
    "        output = {}\n",
    "        for key, val in self.parallel_layer.named_parameters():\n",
    "            output[key] = val\n",
    "        return output\n",
    "\n",
    "    @register(Dispatch.MEGATRON_COMPUTE)\n",
    "    def run_layer(self, x):\n",
    "        x = x.to(\"cuda\")\n",
    "        y = self.parallel_layer(x)\n",
    "        return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "id": "a655271d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "layer_cls = RayClassWithInitArgs(cls=MLPLayerWorker)\n",
    "layer_worker_group = NVMegatronRayWorkerGroup(\n",
    "    resource_pool=resource_pool,\n",
    "    ray_cls_with_init=layer_cls,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "id": "f105ebee",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4 4 1 1\n"
     ]
    }
   ],
   "source": [
    "print(layer_worker_group.world_size, layer_worker_group.tp_size, layer_worker_group.pp_size, layer_worker_group.dp_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "38655091",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ffn_hidden_size = 11008\n",
    "batch_size = 16\n",
    "seq_len = 2048\n",
    "hidden_size = 4096\n",
    "\n",
    "config = OmegaConf.create(\n",
    "    {\n",
    "        \"hidden_size\": hidden_size,\n",
    "        \"intermediate_size\": ffn_hidden_size,\n",
    "        \"hidden_act\": \"silu\",\n",
    "        \"pretraining_tp\": 1,\n",
    "        \"tp\": layer_worker_group.tp_size,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "id": "a026efca",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "x = torch.rand(size=(seq_len, batch_size, hidden_size), dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "id": "f5fcaf13",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[None, None, None, None]"
      ]
     },
     "execution_count": 179,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer_worker_group.init_model(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "id": "3f5cc9b4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2048, 16, 4096])\n"
     ]
    }
   ],
   "source": [
    "output = layer_worker_group.run_layer([x])  # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\n",
    "print(output[0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "id": "49792210",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Shutdown ray cluster\n",
    "ray.shutdown()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}