Skip to content

Transformer is running with float32 instead of bfloat16 ! #1525

@githubsgi

Description

@githubsgi

Bug description

Modified the Llama3 model.py to print the dtype at each stage as follows and ran with just 1 rank. The modified forward method is:

    def forward(
        self,
        tokens: torch.Tensor,
        eos_id: int | None = None,
        input_batch: torch.Tensor | None = None,
    ):
        """
        Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled.
                If pipeline parallelism is enabled, this will be the input token indices
                for the ranks on the first pipeline stage. This will be the activation of the
                previous pipeline stage if the current rank is not on the first stage.
            input_batch (torch.Tensor): The input batch read from the dataloader.
                This will always be the input batch regardless of the pipeline stage.
                This field is required for non-first PP stages to perform document
                masking attention (to analyze the boundary of the document).

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.

        """
        if self.model_args.use_flex_attn:
            init_attention_mask(
                input_batch if input_batch is not None else tokens, eos_id=eos_id
            )

        print (f"tokens.dtype {tokens.dtype}")
        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
        print (f"h.dtype {h.dtype}")

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)
            print (f"h.dtype {h.dtype}")

        h = self.norm(h) if self.norm else h
        print (f"h.dtype {h.dtype}")
        output = self.output(h) if self.output else h
        print (f"output.dtype {h.dtype}")
        return output

Seeing only float32 datatypes as follows.

tokens.dtype torch.int64
h.dtype torch.float32
h.dtype torch.float32
h.dtype torch.float32
h.dtype torch.float32
h.dtype torch.float32
h.dtype torch.float32
h.dtype torch.float32
h.dtype torch.float32
output.dtype torch.float32

The config is:

model.toml', 'dump_folder': './outputs', 'description': 'Llama 3 debug training', 'use_for_integration_test': True, 'print_args': True}, 'profiling': {'enable_profiling': False, 'save_traces_folder': 'profile_trace', 'profile_freq': 10, 'enable_memory_snapshot': False, 'save_memory_snapshot_folder': 'memory_snapshot'}, 'metrics': {'log_freq': 1, 'enable_tensorboard': False, 'disable_color_printing': False, 'save_tb_folder': 'tb', 'save_for_all_ranks': False, 'enable_wandb': False}, 'model': {'name': 'llama3', 'flavor': 'debugmodel', 'tokenizer_path': './tests/assets/tokenizer', 'converters': [], 'print_after_conversion': False}, 'optimizer': {'name': 'AdamW', 'lr': 0.0008, 'beta1': 0.9, 'beta2': 0.95, 'eps': 1e-08, 'weight_decay': 0.1, 'implementation': 'fused', 'early_step_in_backward': False}, 'lr_scheduler': {'warmup_steps': 2, 'decay_ratio': 0.8, 'decay_type': 'linear', 'min_lr_factor': 0.0}, 'training': {'dataset': 'c4_test', 'dataset_path': None, 'local_batch_size': 8, 'global_batch_size': -1, 'seq_len': 2048, 'max_norm': 1.0, 'steps': 10, 'enable_cpu_offload': False, 'mixed_precision_param': 'bfloat16', 'mixed_precision_reduce': 'float32', 'compile': False, 'gc_freq': 50, 'gc_debug': False, 'seed': None, 'deterministic': False}, 'parallelism': {'data_parallel_replicate_degree': 1, 'enable_compiled_autograd': False, 'data_parallel_shard_degree': -1, 'fsdp_reshard_after_forward': 'default', 'tensor_parallel_degree': 1, 'disable_loss_parallel': False, 'enable_async_tensor_parallel': False, 'pipeline_parallel_degree': 1, 'pipeline_parallel_split_points': [], 'module_fqns_per_model_part': None, 'pipeline_parallel_first_stage_less_layers': 1, 'pipeline_parallel_last_stage_less_layers': 1, 'pipeline_parallel_layers_per_stage': None, 'pipeline_parallel_schedule': '1F1B', 'pipeline_parallel_schedule_csv': '', 'pipeline_parallel_microbatch_size': 1, 'context_parallel_degree': 1, 'context_parallel_rotate_method': 'allgather', 'expert_parallel_degree': 1}, 'checkpoint': 
{'enable_checkpoint': False, 'folder': 'checkpoint', 'interval': 10, 'initial_load_path': None, 'initial_load_model_only': True, 'initial_load_in_hf': False, 'last_save_model_only': False, 'last_save_in_hf': False, 'export_dtype': 'float32', 'async_mode': 'disabled', 'keep_latest_k': 10, 'load_step': -1, 'exclude_from_loading': [], 'enable_first_step_checkpoint': False, 'create_seed_checkpoint': False}, 'activation_checkpoint': {'mode': 'selective', 'selective_ac_option': '2', 'per_op_sac_force_recompute_mm_shapes_by_fqns': ['moe.router.gate']}, 'float8': {'enable_fsdp_float8_all_gather': False, 'precompute_float8_dynamic_scale_for_fsdp': False, 'recipe_name': None, 'filter_fqns': ['output'], 'emulate': False, 'moe_fqns_prototype': []}, 'mx': {'mxfp8_dim1_cast_kernel_choice': 'triton', 'recipe_name': 'mxfp8_cublas', 'filter_fqns': ['output'], 'moe_fqns_prototype': []}, 'comm': {'init_timeout_seconds': 300, 'train_timeout_seconds': 100, 'trace_buf_size': 20000, 'save_traces_folder': 'comm_traces'}, 'memory_estimation': {'enabled': False, 'disable_fake_mode': False}, 'fault_tolerance': {'enable': False, 'process_group': 'gloo', 'process_group_timeout_ms': 10000, 'replica_id': 0, 'group_size': 0, 'min_replica_size': 1, 'semi_sync_method': None, 'sync_steps': 5, 'should_quantize': False, 'fragment_sync_delay': 0, 'fragment_update_alpha': 0.0}, 'experimental': {'custom_import': '', 'custom_args_module': ''}, 'validation': {'enabled': False, 'dataset': 'c4_validation', 'dataset_path': None, 'local_batch_size': 8, 'seq_len': 2048, 'freq': 5, 'steps': 10}}

Versions

TorchTitan master.

Metadata

Metadata

Assignees

No one assigned

    Labels

    question: Further information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions