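import torch

from score_sde.models.projected_discriminator import ProjectedDiscriminator

# Build the projected discriminator (num_discs=4, backbone conditioning size 768).
discr = ProjectedDiscriminator(num_discs=4, backbone_kwargs={"cond_size": 768})

# Dummy input: a batch of one 3-channel 224x224 image.
x = torch.randn(1, 3, 224, 224)
# Timestep tensor of shape (1,); randint's upper bound is exclusive, so this is always 0.
t = torch.randint(0, 1, size=(1,))
# Conditioning tuple: (None, 77 embeddings of size 768, all-True boolean mask over the 77 positions).
cond = (None, torch.randn(1, 77, 768), torch.ones(1, 77, dtype=torch.bool))

y = discr(x, t, x, cond=cond)
print(y.shape)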