diff --git a/session/smollm2-speech-semantics/llama-tts.ipynb b/session/smollm2-speech-semantics/llama-tts.ipynb new file mode 100644 index 0000000..9559baa --- /dev/null +++ b/session/smollm2-speech-semantics/llama-tts.ipynb @@ -0,0 +1,2930 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "77ab21bc", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '1'" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c9c65f61", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import copy\n", + "import numpy as np\n", + "from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaDecoderLayer\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig\n", + "from transformers import AddedToken\n", + "\n", + "IGNORE_INDEX = -100" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f2e0570c", + "metadata": {}, + "outputs": [], + "source": [ + "def lengths_to_padding_mask(lens):\n", + " bsz, max_lens = lens.size(0), torch.max(lens).item()\n", + " mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)\n", + " mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)\n", + " return mask\n", + "\n", + "\n", + "def _uniform_assignment(src_lens, tgt_lens):\n", + " tgt_indices = torch.arange(torch.max(tgt_lens)).expand(len(tgt_lens), -1).to(tgt_lens.device)\n", + " ratio = tgt_lens / src_lens\n", + " index_t = (tgt_indices / ratio.view(-1, 1)).long()\n", + " return index_t\n", + "\n", + "class SpeechGeneratorCTC(torch.nn.Module):\n", + " def __init__(self, config, ctc_upsample_factor = 26, unit_vocab_size = 1024):\n", + " super().__init__()\n", + " n_layers, n_dims, n_heads, n_inter_dims = 2,4096,32,11008\n", + " _config = copy.deepcopy(config)\n", + " _config.hidden_size = n_dims\n", + " _config.num_hidden_layers = n_layers\n", + " _config.num_attention_heads = n_heads\n", + " _config.num_key_value_heads = n_heads\n", + " _config.intermediate_size = n_inter_dims\n", + " _config._attn_implementation = \"flash_attention_2\"\n", + " self.upsample_factor = ctc_upsample_factor\n", + " self.input_proj = nn.Linear(config.hidden_size, n_dims)\n", + " self.layers = nn.ModuleList(\n", + " [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]\n", + " )\n", + " self.unit_vocab_size = unit_vocab_size\n", + " self.output_proj = nn.Linear(n_dims, self.unit_vocab_size + 1)\n", + " \n", + " def upsample(self, reps, tgt_units=None):\n", + " src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)\n", + " up_lens = src_lens * self.upsample_factor\n", + " if tgt_units is not None:\n", + " tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)\n", + " up_lens = torch.max(up_lens, tgt_lens)\n", + " reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)\n", + " padding_mask = lengths_to_padding_mask(up_lens)\n", + " mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(\n", + " padding_mask, 0\n", + " )\n", + " copied_reps = torch.gather(\n", + " reps,\n", + " 1,\n", + " mapped_inputs.unsqueeze(-1).expand(\n", + " *mapped_inputs.size(), reps.size(-1)\n", + " ),\n", + " )\n", + " copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)\n", + " position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device)\n", + " return copied_reps, ~padding_mask, position_ids\n", + "\n", + " def forward(self, tgt_reps, labels, tgt_units):\n", + " tgt_label_reps = []\n", + " for tgt_rep, label in zip(tgt_reps, labels):\n", + " tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX])\n", + " hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)\n", + " hidden_states = self.input_proj(hidden_states)\n", + " for layer in self.layers:\n", + " layer_outputs = layer(\n", + " hidden_states,\n", + " attention_mask=attention_mask,\n", + " position_ids=position_ids,\n", + " )\n", + " hidden_states = layer_outputs[0]\n", + " ctc_logits = self.output_proj(hidden_states)\n", + " ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)\n", + " ctc_lens = attention_mask.long().sum(dim=-1)\n", + " ctc_tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)\n", + " ctc_tgt_mask = ~lengths_to_padding_mask(ctc_tgt_lens)\n", + " ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask)\n", + " ctc_loss = F.ctc_loss(\n", + " ctc_lprobs.transpose(0, 1),\n", + " ctc_tgt_flat,\n", + " ctc_lens,\n", + " ctc_tgt_lens,\n", + " reduction=\"sum\",\n", + " zero_infinity=True,\n", + " blank=self.unit_vocab_size\n", + " )\n", + " ctc_loss /= ctc_tgt_lens.sum().item()\n", + " return ctc_loss\n", + " \n", + " def predict(self, tgt_reps):\n", + " hidden_states, attention_mask, position_ids = self.upsample([tgt_reps])\n", + " hidden_states = self.input_proj(hidden_states)\n", + " for layer in self.layers:\n", + " layer_outputs = layer(\n", + " hidden_states,\n", + " attention_mask=attention_mask,\n", + " position_ids=position_ids,\n", + " )\n", + " hidden_states = layer_outputs[0]\n", + " ctc_logits = self.output_proj(hidden_states)\n", + " ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)\n", + " ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size)\n", + " return ctc_pred" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "85ac0866", + "metadata": {}, + "outputs": [], + "source": [ + "class LlamaTTS(LlamaForCausalLM):\n", + " def __init__(self, config):\n", + " super().__init__(config)\n", + " self.speech_generator = SpeechGeneratorCTC(self.config)\n", + " \n", + " def forward(self, tgt_units = None, **kwargs):\n", + " return super().forward(**kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a67b7345", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of LlamaTTS were not initialized from the model checkpoint at HuggingFaceTB/SmolLM2-135M-Instruct and are newly initialized: ['speech_generator.input_proj.bias', 'speech_generator.input_proj.weight', 'speech_generator.layers.0.input_layernorm.weight', 'speech_generator.layers.0.mlp.down_proj.weight', 'speech_generator.layers.0.mlp.gate_proj.weight', 'speech_generator.layers.0.mlp.up_proj.weight', 'speech_generator.layers.0.post_attention_layernorm.weight', 'speech_generator.layers.0.self_attn.k_proj.weight', 'speech_generator.layers.0.self_attn.o_proj.weight', 'speech_generator.layers.0.self_attn.q_proj.weight', 'speech_generator.layers.0.self_attn.v_proj.weight', 'speech_generator.layers.1.input_layernorm.weight', 'speech_generator.layers.1.mlp.down_proj.weight', 'speech_generator.layers.1.mlp.gate_proj.weight', 'speech_generator.layers.1.mlp.up_proj.weight', 'speech_generator.layers.1.post_attention_layernorm.weight', 'speech_generator.layers.1.self_attn.k_proj.weight', 'speech_generator.layers.1.self_attn.o_proj.weight', 'speech_generator.layers.1.self_attn.q_proj.weight', 'speech_generator.layers.1.self_attn.v_proj.weight', 'speech_generator.output_proj.bias', 'speech_generator.output_proj.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model = LlamaTTS.from_pretrained('HuggingFaceTB/SmolLM2-135M-Instruct',\n", + " torch_dtype = torch.bfloat16)\n", + "_ = model.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "69856f4c", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolLM2-135M-Instruct')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a48d1365", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "new = ['<|speaker|>']\n", + "new = [AddedToken(t) for t in new]\n", + "tokenizer.add_tokens(new)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b47f58c6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Embedding(49153, 576, padding_idx=2)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.resize_token_embeddings(len(tokenizer), mean_resizing=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9dd7b0af", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "360298" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.read_parquet('data/train-00000-of-00001.parquet').to_dict(orient = 'records')\n", + "len(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "24d98e6c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'transcription': 'Sedangkan dalam bahasa Perancis , frira hanya bererti menggoreng di dalam minyak goreng yang banyak hingga terendam .',\n", + " 'speaker': 'Osman',\n", + " 'speaker_id': 1,\n", + " 'gender': 'male',\n", + " 'utterance_pitch_mean': 140.82264709472656,\n", + " 'utterance_pitch_std': 37.72042465209961,\n", + " 'snr': 69.54813385009766,\n", + " 'c50': 55.92512130737305,\n", + " 'speech_duration': 6.648750000000001,\n", + " 'stoi': 0.9943549633026123,\n", + " 'si-sdr': 16.59736442565918,\n", + " 'pesq': 3.5911829471588135,\n", + " 'pitch': 'slightly high pitch',\n", + " 'speaking_rate': 'very slowly',\n", + " 'noise': 'very clear',\n", + " 'reverberation': 'very confined sounding',\n", + " 'speech_monotony': 'very monotone',\n", + " 'prompt': 'Osman, a male speaker with a moderately high-pitched voice delivers an animated and expressive speech in a confined room with very clear recording. His voice is very monotone, and he speaks very slowly.',\n", + " 'audio_filename': 'combine-audio/0.mp3'}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "row = df[0]\n", + "row" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "485769b1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input_ids': tensor([[49152, 11062, 1483, 49152, 67, 277, 604, 24184, 287, 20462,\n", + " 278, 1287, 14852, 3017, 1148, 271, 3297, 1669, 317, 294,\n", + " 37430, 39802, 9573, 89, 1800, 1057, 390, 863, 801, 287,\n", + " 20462, 1079, 105, 494, 310, 390, 863, 33856, 278, 1111,\n", + " 494, 27427, 8662, 252, 518, 268, 332, 1673]],\n", + " device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],\n", + " device='cuda:0')}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "speaker = f\"<|speaker|>{row['speaker']}<|speaker|>\"\n", + "len_speaker_token = len(tokenizer.tokenize(speaker))\n", + "prompt = f\"{speaker}{row['transcription']}\"\n", + "input_ids = tokenizer(prompt, add_special_tokens = False, return_tensors = 'pt').to('cuda')\n", + "input_ids" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "b7ba7568", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1224,)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "splitted = row['audio_filename'].split('/')\n", + "new_f = '/'.join([splitted[0] + '_vqgan'] + splitted[1:]).replace('.mp3', '.npy')\n", + "speech_token = np.load(new_f)\n", + "speech_token.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "9c1c63e7", + "metadata": {}, + "outputs": [], + "source": [ + "tgt_units = torch.tensor([speech_token]).to('cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "6f01b389", + "metadata": {}, + "outputs": [], + "source": [ + "o = model(**input_ids, output_hidden_states = True, tgt_units = tgt_units)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "469e7acb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CausalLMOutputWithPast(loss=None, logits=tensor([[[ 15.6250, 6.7812, 10.1875, ..., 11.5625, 3.0938, -0.6523],\n", + " [ 18.3750, 11.1875, 15.3750, ..., 15.6250, 8.5000, -0.2578],\n", + " [ 4.6250, -18.0000, -15.7500, ..., -3.1406, -14.8125, 0.6172],\n", + " ...,\n", + " [ 19.3750, -0.9219, -0.0552, ..., 12.3125, 9.8125, -0.8047],\n", + " [ 9.8125, -5.8125, -5.3750, ..., 8.1875, -2.8594, -1.0000],\n", + " [ 8.1250, 5.7188, 10.0000, ..., 7.8750, -1.6172, 1.0859]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), past_key_values=((tensor([[[[ 0.4355, -0.5117, -0.0317, ..., -0.2715, -0.0140, 0.4023],\n", + " [-0.9375, 0.1650, -0.4316, ..., -0.6445, 2.7656, 0.1089],\n", + " [ 1.8750, 0.3926, -0.4883, ..., -0.3770, 2.1250, -0.0923],\n", + " ...,\n", + " [-0.6406, 0.2910, 0.2988, ..., -0.5352, 2.2969, 0.1279],\n", + " [ 1.0312, 0.2490, 0.3945, ..., 0.1216, 1.6641, 0.2891],\n", + " [ 1.1172, -0.3848, 0.4121, ..., 0.7500, 0.6484, -0.1309]],\n", + "\n", + " [[ 0.0298, 0.8555, -0.3477, ..., 0.5469, 0.1504, -0.0227],\n", + " [ 1.4062, 0.3320, -0.4023, ..., -0.5234, 0.4785, 0.7031],\n", + " [ 1.8906, 1.1250, -1.4609, ..., -0.3730, 0.3984, 1.4844],\n", + " ...,\n", + " [ 1.5469, -0.5820, -0.6797, ..., -0.0286, 0.2969, 1.0234],\n", + " [ 1.2344, 0.1328, 0.4160, ..., 0.2812, 0.4922, 1.5859],\n", + " [ 1.2812, 2.0781, 1.3828, ..., -0.4707, -0.3066, 0.5586]],\n", + "\n", + " [[-0.0074, -0.2285, 0.1348, ..., -0.2246, 0.1514, -0.1523],\n", + " [ 0.8359, 0.1992, -0.3320, ..., 0.5273, 0.5469, 0.4102],\n", + " [ 1.3672, 0.6602, 0.0391, ..., -0.3320, 0.4238, -0.5039],\n", + " ...,\n", + " [ 1.0781, -0.2578, 0.2734, ..., -0.0527, 0.7461, -0.2070],\n", + " [ 0.3555, 0.6328, -0.6641, ..., -0.9023, 0.7227, -1.0312],\n", + " [-0.0107, 0.6445, -0.2852, ..., -0.8125, 0.0859, -0.7227]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[[-0.0126, -0.0204, -0.0112, ..., -0.0161, -0.0079, 0.0006],\n", + " [ 0.0386, -0.0391, -0.0015, ..., -0.0003, -0.0175, -0.0157],\n", + " [ 0.0425, -0.0238, -0.0193, ..., 0.0330, -0.0011, 0.0204],\n", + " ...,\n", + " [ 0.0131, 0.0064, -0.0208, ..., -0.0176, -0.0066, -0.0466],\n", + " [ 0.0236, -0.0082, 0.0281, ..., -0.0050, -0.0569, 0.0062],\n", + " [ 0.0327, 0.0115, 0.0352, ..., -0.0058, -0.0134, 0.0105]],\n", + "\n", + " [[-0.0413, 0.0150, -0.0825, ..., -0.0193, 0.0168, -0.0184],\n", + " [-0.0189, 0.0510, 0.0320, ..., 0.0132, -0.0361, 0.0525],\n", + " [-0.0103, -0.0195, 0.0620, ..., 0.0082, -0.0264, 0.0080],\n", + " ...,\n", + " [ 0.0042, 0.0532, 0.0054, ..., 0.0260, -0.0227, 0.0120],\n", + " [ 0.0452, 0.0325, 0.0894, ..., -0.0339, -0.0718, 0.0339],\n", + " [-0.0309, 0.0540, 0.0221, ..., 0.0101, -0.0179, 0.0830]],\n", + "\n", + " [[ 0.0070, -0.0378, -0.0299, ..., 0.0327, -0.0408, -0.0479],\n", + " [ 0.0004, -0.0195, 0.0488, ..., -0.0312, 0.0194, -0.0317],\n", + " [ 0.0107, 0.0053, 0.0228, ..., -0.0166, 0.0371, 0.0087],\n", + " ...,\n", + " [-0.0330, -0.0128, 0.0305, ..., -0.0222, 0.0120, 0.0283],\n", + " [ 0.1177, -0.0029, -0.0432, ..., -0.0894, -0.0013, 0.0364],\n", + " [ 0.0243, -0.0361, -0.0286, ..., 0.0020, -0.0056, 0.0090]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=)), (tensor([[[[ 4.0312e+00, 1.2109e+00, -4.0000e+00, ..., -1.9844e+00,\n", + " -1.3594e+00, -2.5000e+00],\n", + " [-4.5000e+00, -7.8516e-01, -2.8594e+00, ..., -2.8711e-01,\n", + " -1.0312e+00, -2.5195e-01],\n", + " [-5.8750e+00, -2.7656e+00, -2.3750e+00, ..., 4.4141e-01,\n", + " -8.6328e-01, -1.0547e+00],\n", + " ...,\n", + " [-5.6562e+00, -2.9102e-01, 3.2969e+00, ..., 5.5859e-01,\n", + " -1.7578e+00, -1.1328e+00],\n", + " [-5.1250e+00, -1.7656e+00, 3.4062e+00, ..., -8.3618e-03,\n", + " -1.5781e+00, -1.0938e+00],\n", + " [-2.7344e+00, -1.7578e+00, 8.3984e-01, ..., -2.1484e-01,\n", + " -4.9609e-01, -1.3750e+00]],\n", + "\n", + " [[ 3.9453e-01, -1.6406e+00, 1.0547e-01, ..., -1.4219e+00,\n", + " -1.0547e+00, -5.2344e-01],\n", + " [ 9.4531e-01, -1.1719e+00, 1.9922e-01, ..., -9.7266e-01,\n", + " 1.7853e-03, 9.4531e-01],\n", + " [ 1.3750e+00, -3.4375e-01, -2.0703e-01, ..., 5.8594e-01,\n", + " -7.5195e-02, -2.5586e-01],\n", + " ...,\n", + " [ 1.3906e+00, -1.5156e+00, -1.8359e-01, ..., 5.0000e-01,\n", + " -1.2266e+00, 5.6250e-01],\n", + " [ 7.6172e-01, -8.8281e-01, -4.3555e-01, ..., 5.5859e-01,\n", + " -4.4727e-01, 8.7109e-01],\n", + " [-2.7031e+00, 2.0938e+00, 2.0781e+00, ..., -5.9375e-01,\n", + " -1.1094e+00, 3.0273e-01]],\n", + "\n", + " [[ 3.2344e+00, -2.1250e+00, 1.8750e+00, ..., -1.7383e-01,\n", + " 3.2617e-01, 3.7695e-01],\n", + " [ 3.0469e-01, -1.4531e+00, 1.1797e+00, ..., -1.0391e+00,\n", + " 8.6328e-01, -3.9453e-01],\n", + " [-7.7344e-01, -8.6328e-01, 6.0156e-01, ..., 5.5469e-01,\n", + " 6.5234e-01, 1.3359e+00],\n", + " ...,\n", + " [ 1.5938e+00, -2.2344e+00, -1.7031e+00, ..., -2.0605e-01,\n", + " 1.3281e+00, 1.0010e-01],\n", + " [-8.2422e-01, -1.8438e+00, -3.9844e-01, ..., -1.1953e+00,\n", + " -1.4746e-01, -1.7090e-01],\n", + " [-1.7656e+00, -8.7891e-02, -5.8594e-01, ..., -7.3828e-01,\n", + " -2.6172e-01, -5.7031e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-2.6172e-01, -3.0664e-01, -2.2266e-01, ..., 3.0859e-01,\n", + " -1.9141e-01, 2.9297e-01],\n", + " [-1.4844e+00, 8.2520e-02, -1.8359e-01, ..., 1.6406e-01,\n", + " 3.2031e-01, -2.1729e-02],\n", + " [ 4.1797e-01, 2.3047e-01, -5.1953e-01, ..., 7.9102e-02,\n", + " 1.4832e-02, -1.0693e-01],\n", + " ...,\n", + " [-5.8594e-01, -2.4707e-01, 1.1084e-01, ..., 8.3203e-01,\n", + " -7.0801e-02, 3.1250e-01],\n", + " [-3.3789e-01, -8.2031e-02, 3.6523e-01, ..., 9.0234e-01,\n", + " -4.2969e-01, 2.8125e-01],\n", + " [-1.4062e-01, -1.0864e-02, 3.4570e-01, ..., -8.3160e-04,\n", + " -1.7624e-03, 4.5898e-01]],\n", + "\n", + " [[ 6.9824e-02, 3.4180e-01, -1.8848e-01, ..., -9.9609e-02,\n", + " 6.3281e-01, -1.9629e-01],\n", + " [ 5.0537e-02, -5.8838e-02, 1.7090e-01, ..., -2.6562e-01,\n", + " 7.0801e-02, 1.8555e-01],\n", + " [-8.5938e-02, 1.4551e-01, -8.9355e-02, ..., 1.7871e-01,\n", + " 3.7598e-02, -5.3516e-01],\n", + " ...,\n", + " [ 3.5156e-02, 2.1240e-02, -1.1621e-01, ..., -2.9883e-01,\n", + " -5.2979e-02, 1.8848e-01],\n", + " [-1.7676e-01, -2.8516e-01, -1.8945e-01, ..., -1.2158e-01,\n", + " 8.3984e-02, -3.0859e-01],\n", + " [-2.7734e-01, -7.0703e-01, 1.7480e-01, ..., 2.9688e-01,\n", + " 6.4844e-01, 3.7891e-01]],\n", + "\n", + " [[ 2.0020e-01, 1.1133e-01, 3.3984e-01, ..., 2.6733e-02,\n", + " -2.1875e-01, -3.6719e-01],\n", + " [ 2.7344e-01, -1.8945e-01, 1.0156e-01, ..., -1.1475e-01,\n", + " -3.0469e-01, 2.6562e-01],\n", + " [-5.1172e-01, 2.5586e-01, 2.2852e-01, ..., 4.9023e-01,\n", + " 1.9238e-01, -1.6016e-01],\n", + " ...,\n", + " [-4.3750e-01, -3.4570e-01, 3.7305e-01, ..., 2.6855e-02,\n", + " -1.9629e-01, -2.6953e-01],\n", + " [-3.8281e-01, -1.8555e-01, -4.0430e-01, ..., -5.4688e-01,\n", + " -1.3574e-01, -5.6250e-01],\n", + " [ 7.5684e-02, 1.9043e-01, 5.5908e-02, ..., -1.2500e-01,\n", + " 1.6309e-01, 2.1680e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 4.0000, -3.2812, 1.6328, ..., 1.8984, 6.0625, 1.8047],\n", + " [-0.0547, -1.8828, 2.8125, ..., 1.5703, 4.8750, 1.2109],\n", + " [-2.7969, -0.5078, 1.7266, ..., 1.6406, 5.1875, 1.5938],\n", + " ...,\n", + " [ 0.1328, -3.1406, -2.0938, ..., 1.2656, 4.8750, 0.1553],\n", + " [-3.9375, -2.2344, -2.6562, ..., 2.9531, 4.4688, 0.6172],\n", + " [-3.6719, 0.2969, -2.1406, ..., -0.2812, 5.3125, 1.8125]],\n", + "\n", + " [[ 0.1670, -0.1475, 0.6328, ..., -3.3906, 2.1719, -0.8945],\n", + " [ 0.8359, -0.3945, 1.4531, ..., -3.2969, 2.0156, -1.2578],\n", + " [ 0.3535, -0.9062, 0.2637, ..., -3.0312, 1.3984, -0.4355],\n", + " ...,\n", + " [ 3.4688, -2.8750, -2.4531, ..., -1.8438, -0.4180, -0.0845],\n", + " [ 3.0312, -0.5195, -2.5000, ..., -3.7656, 0.7266, -1.2031],\n", + " [ 0.6680, -0.6406, -0.0461, ..., -3.0469, 1.5469, 0.1152]],\n", + "\n", + " [[-0.9023, -0.9844, 0.3555, ..., 0.1123, -1.8828, -0.3984],\n", + " [-0.7773, 0.7617, -0.6484, ..., -0.6523, -1.1484, -0.2988],\n", + " [-0.5273, -0.0535, -0.4551, ..., -0.5352, -1.5000, 0.3633],\n", + " ...,\n", + " [-3.1250, -2.0469, 0.5156, ..., 1.7656, -2.3750, -1.5156],\n", + " [-0.4570, -0.8047, -1.0000, ..., 1.1250, 1.2109, -1.4219],\n", + " [ 0.6523, 0.4102, 0.4102, ..., 0.5312, -2.3750, 0.7969]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[[-0.0757, -0.0396, -0.1055, ..., -0.0952, -0.0422, 0.0464],\n", + " [-0.1582, 0.3809, -0.0977, ..., 0.0295, -0.2793, 0.1816],\n", + " [-0.1553, -0.4355, -0.1982, ..., -0.4551, 0.4199, 0.1680],\n", + " ...,\n", + " [-0.0981, -0.3555, -0.1162, ..., -0.7422, 0.1904, -0.2715],\n", + " [ 0.1797, 0.4473, -0.2480, ..., -0.8086, -0.1235, 0.1494],\n", + " [ 0.1035, -0.1221, 0.3203, ..., -0.1934, -0.4492, -0.0986]],\n", + "\n", + " [[-0.0116, -0.4180, -0.0227, ..., 0.2520, -0.1826, -0.1299],\n", + " [-0.3457, -0.2373, -0.2715, ..., -0.1138, -0.5312, -0.2109],\n", + " [ 0.0356, 0.1631, 0.0796, ..., 0.1914, -0.1641, -0.0781],\n", + " ...,\n", + " [-0.7031, -0.3164, 0.6211, ..., 0.2148, -0.7227, -0.1885],\n", + " [ 0.1309, -0.5742, -0.0479, ..., 0.1187, -0.7500, -0.1934],\n", + " [-0.0535, -0.1216, 0.0253, ..., -0.1748, -0.0031, 0.0513]],\n", + "\n", + " [[-0.1128, -0.2148, -0.1523, ..., 0.2148, -0.2129, -0.0425],\n", + " [ 0.2910, -0.1904, 0.0972, ..., -0.0977, -0.2793, 0.4512],\n", + " [-0.0046, -0.0082, 0.1328, ..., -0.0991, -0.1543, -0.0461],\n", + " ...,\n", + " [ 0.8047, 0.3730, 0.5234, ..., -0.3965, 0.3867, 0.0393],\n", + " [-0.0154, -0.1021, -0.1367, ..., -0.0688, -0.3691, 0.1875],\n", + " [ 0.1045, 0.2520, -0.3340, ..., -0.0830, -0.0752, -0.0432]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=)), (tensor([[[[ 0.1797, -0.1260, 0.1387, ..., 1.6719, 1.0312, 0.2207],\n", + " [ 0.2676, -0.0815, 0.0991, ..., 1.5625, 1.0391, 0.3652],\n", + " [-0.0161, -0.1943, 0.0211, ..., 1.6406, 1.2500, 0.5391],\n", + " ...,\n", + " [ 5.7500, -2.3594, -3.3438, ..., 1.3125, 0.4062, 1.2500],\n", + " [ 0.7812, -2.5938, -2.2344, ..., 0.5117, 1.1250, 1.3359],\n", + " [-2.8594, 0.1445, -3.0156, ..., 1.6719, -0.9297, 1.1406]],\n", + "\n", + " [[ 0.1055, 0.0106, -0.1069, ..., -0.7812, 0.5117, -0.8477],\n", + " [ 0.0801, -0.1001, 0.0508, ..., -0.8750, 0.6055, -0.6172],\n", + " [ 0.0359, 0.0166, 0.0723, ..., -0.4512, 0.5117, -0.6484],\n", + " ...,\n", + " [ 2.5469, 1.4766, 1.1328, ..., 1.7344, 0.5195, 2.3906],\n", + " [-0.6406, 3.3750, -0.9375, ..., 0.6094, 0.6133, 3.0000],\n", + " [-2.0469, 0.9609, -0.4004, ..., -0.1904, -0.2148, 0.3789]],\n", + "\n", + " [[-0.0234, -0.0483, -0.0554, ..., -1.5000, 0.2949, -1.6016],\n", + " [-0.0952, -0.0527, 0.0283, ..., -1.4375, 0.2373, -1.4297],\n", + " [-0.1680, -0.0378, 0.0496, ..., -1.6094, 0.3184, -1.6641],\n", + " ...,\n", + " [-0.5352, 0.5469, -1.2031, ..., 3.5625, 1.0234, -2.1250],\n", + " [-1.3516, -0.1992, -0.6094, ..., 3.5000, 0.3438, -1.9297],\n", + " [-0.3379, -0.5938, -0.4863, ..., 0.4766, 0.0125, -3.3594]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[[ 4.4189e-02, -7.9346e-03, -1.1475e-01, ..., -8.0566e-02,\n", + " 6.2500e-02, -8.8867e-02],\n", + " [ 2.6172e-01, -1.3672e-01, -1.8750e-01, ..., -3.8330e-02,\n", + " -9.6680e-02, -4.8096e-02],\n", + " [ 7.5684e-02, 1.7319e-03, 1.2741e-03, ..., -1.8066e-02,\n", + " 2.5269e-02, -5.6152e-02],\n", + " ...,\n", + " [-2.2500e+00, 9.7656e-01, 8.7891e-01, ..., -7.5391e-01,\n", + " 6.8750e-01, 2.3071e-02],\n", + " [ 8.0078e-01, -5.1953e-01, -3.8818e-02, ..., -2.8125e-01,\n", + " 1.4062e-01, -6.3477e-02],\n", + " [-3.1250e-01, 1.9922e-01, -5.9375e-01, ..., -2.1094e-01,\n", + " 5.5859e-01, 7.1777e-02]],\n", + "\n", + " [[ 9.5215e-02, -8.8379e-02, -5.0293e-02, ..., 3.5400e-02,\n", + " -7.4219e-02, -9.1309e-02],\n", + " [ 1.5320e-02, -9.0820e-02, -9.6680e-02, ..., -5.2002e-02,\n", + " -1.0010e-01, -2.5000e-01],\n", + " [ 9.9182e-04, 3.6926e-03, -9.0820e-02, ..., 2.9663e-02,\n", + " -9.4238e-02, -9.7168e-02],\n", + " ...,\n", + " [-4.8584e-02, -6.6016e-01, 1.9531e-01, ..., 1.9297e+00,\n", + " -2.6367e-01, 7.8906e-01],\n", + " [-2.6367e-01, -4.5703e-01, 2.3145e-01, ..., -1.7969e-01,\n", + " -2.2754e-01, 3.7109e-01],\n", + " [-5.6641e-01, -3.3789e-01, -3.3594e-01, ..., -1.7090e-01,\n", + " 5.7031e-01, 7.4609e-01]],\n", + "\n", + " [[-3.0060e-03, 2.9144e-03, 3.3691e-02, ..., 7.0801e-02,\n", + " 2.4780e-02, 7.5684e-02],\n", + " [ 9.4238e-02, 1.3086e-01, -3.2227e-02, ..., 1.4771e-02,\n", + " -7.1289e-02, 5.2979e-02],\n", + " [ 3.6377e-02, -6.2256e-02, 3.0518e-02, ..., -5.9814e-02,\n", + " 3.1738e-02, -5.4443e-02],\n", + " ...,\n", + " [ 7.8906e-01, 3.3447e-02, -5.4297e-01, ..., -7.0312e-01,\n", + " 4.5312e-01, 6.2109e-01],\n", + " [-1.3965e-01, -3.8818e-02, 8.0078e-01, ..., 2.7148e-01,\n", + " 9.0332e-02, 3.5400e-02],\n", + " [-2.7930e-01, -4.9561e-02, 5.4443e-02, ..., -2.5391e-01,\n", + " 1.2695e-01, 4.3750e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-0.1226, -0.0217, -0.0771, ..., 0.3887, -1.9062, 0.0957],\n", + " [-0.2012, -0.1992, -0.0332, ..., 0.5703, -1.7031, 0.0723],\n", + " [ 0.0405, -0.0947, -0.0645, ..., 0.4941, -1.7500, 0.1406],\n", + " ...,\n", + " [-2.8750, 0.4062, 1.9609, ..., 1.0000, -1.4297, 0.7773],\n", + " [ 0.6523, -0.8047, 0.9336, ..., 1.0703, -1.4453, 1.5547],\n", + " [ 0.6133, -0.6562, 1.6484, ..., 2.3594, -3.6875, 1.2188]],\n", + "\n", + " [[ 0.0270, -0.0610, -0.2471, ..., -0.4375, 0.4414, -0.7891],\n", + " [-0.0713, 0.0747, -0.0503, ..., -0.7773, 0.2441, -0.6250],\n", + " [-0.1406, 0.1328, -0.0200, ..., -0.2949, 0.3184, -0.6797],\n", + " ...,\n", + " [-2.7500, -1.5078, 2.5000, ..., -3.8906, -0.9570, -1.3281],\n", + " [-0.7656, 0.4609, 2.4219, ..., -0.0776, 2.0000, -0.0796],\n", + " [-0.3164, 4.0000, 1.6875, ..., -0.7461, 0.3691, -3.2031]],\n", + "\n", + " [[-0.1895, -0.1084, -0.1504, ..., 1.1094, 0.2148, -1.0859],\n", + " [-0.1768, -0.1396, -0.1621, ..., 1.0781, 0.3691, -0.8672],\n", + " [-0.1094, -0.0850, -0.1953, ..., 1.2266, 0.1641, -0.8945],\n", + " ...,\n", + " [-1.8125, -1.6094, 2.4531, ..., 0.3066, -0.1279, -0.5352],\n", + " [ 0.0312, -1.5078, 3.3750, ..., -0.3535, 0.4316, 0.4141],\n", + " [ 2.0000, 1.5312, 0.5703, ..., -2.0625, -0.6445, -2.0156]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[[ 4.4678e-02, -9.6191e-02, -7.4707e-02, ..., 1.3962e-03,\n", + " 4.8584e-02, -3.7689e-03],\n", + " [-1.1414e-02, -4.3945e-02, 1.6113e-02, ..., -9.8877e-03,\n", + " 2.9053e-02, -6.0059e-02],\n", + " [-2.4292e-02, -2.8076e-02, -9.2773e-02, ..., -3.4485e-03,\n", + " -6.0303e-02, 7.1777e-02],\n", + " ...,\n", + " [-3.1445e-01, -8.5938e-01, -6.3672e-01, ..., 2.3438e-02,\n", + " 1.1279e-01, 2.4512e-01],\n", + " [-7.7637e-02, 7.4609e-01, 9.5312e-01, ..., -5.2344e-01,\n", + " -3.2422e-01, -3.6328e-01],\n", + " [ 9.4238e-02, -1.7480e-01, 4.3945e-01, ..., 1.2256e-01,\n", + " -3.8281e-01, 3.5938e-01]],\n", + "\n", + " [[ 3.8086e-02, 2.2949e-02, 1.4648e-02, ..., -9.4238e-02,\n", + " -9.1797e-02, -7.5684e-02],\n", + " [ 6.7383e-02, 5.6641e-02, 1.1914e-01, ..., -1.2598e-01,\n", + " 7.5195e-02, 4.0039e-02],\n", + " [ 5.7861e-02, 5.2734e-02, 6.7383e-02, ..., -1.2451e-01,\n", + " -3.7354e-02, 6.2500e-02],\n", + " ...,\n", + " [ 6.6406e-01, 6.2891e-01, -1.7480e-01, ..., 1.8433e-02,\n", + " 1.9609e+00, -4.4922e-01],\n", + " [ 1.2988e-01, 6.4453e-01, 1.6406e+00, ..., -2.6758e-01,\n", + " 1.0156e+00, 1.8438e+00],\n", + " [-4.4922e-01, -8.5156e-01, -4.6094e-01, ..., 4.6094e-01,\n", + " 1.1719e+00, 9.4727e-02]],\n", + "\n", + " [[ 4.9316e-02, 6.2012e-02, -5.2979e-02, ..., 5.8105e-02,\n", + " -5.8350e-02, 2.6733e-02],\n", + " [ 1.4453e-01, 6.9824e-02, -9.8633e-02, ..., 3.9062e-02,\n", + " -9.3750e-02, 3.1738e-02],\n", + " [ 4.9072e-02, 6.0303e-02, 3.7354e-02, ..., 6.1279e-02,\n", + " -5.1514e-02, 5.3467e-02],\n", + " ...,\n", + " [-6.5234e-01, -7.6953e-01, 5.2002e-02, ..., -7.8516e-01,\n", + " 3.1055e-01, -4.2969e-01],\n", + " [ 6.5625e-01, 1.9434e-01, 1.9043e-01, ..., -8.5449e-02,\n", + " -7.2656e-01, -6.6406e-01],\n", + " [-5.2344e-01, -5.4297e-01, 2.5586e-01, ..., 2.2461e-01,\n", + " -1.1377e-01, 6.8359e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-1.8359e-01, -1.8677e-02, -9.2773e-02, ..., -9.3750e-02,\n", + " -5.0391e-01, 1.3359e+00],\n", + " [-1.4453e-01, 1.0156e-01, -3.9551e-02, ..., -1.2305e-01,\n", + " -4.2969e-01, 1.3672e+00],\n", + " [ 6.2500e-02, 5.4932e-03, -1.1328e-01, ..., -1.8799e-02,\n", + " -3.3398e-01, 1.4297e+00],\n", + " ...,\n", + " [ 6.7188e-01, 6.1719e-01, 2.5781e+00, ..., -3.4375e-01,\n", + " -6.7188e-01, 5.6641e-01],\n", + " [ 2.8125e+00, 3.5938e-01, 2.3906e+00, ..., -2.0625e+00,\n", + " -5.3906e-01, 9.0820e-02],\n", + " [ 3.0273e-01, 4.6289e-01, 1.7383e-01, ..., 8.3008e-02,\n", + " -4.5703e-01, 6.5234e-01]],\n", + "\n", + " [[-1.0400e-01, -9.1309e-02, -1.4941e-01, ..., -5.9688e+00,\n", + " 7.1777e-02, -8.5547e-01],\n", + " [ 1.3574e-01, -1.9727e-01, -1.2207e-01, ..., -6.0938e+00,\n", + " -1.6895e-01, -9.5703e-01],\n", + " [ 2.6562e-01, -1.5137e-01, -1.2305e-01, ..., -5.9062e+00,\n", + " 3.6621e-03, -6.0547e-01],\n", + " ...,\n", + " [ 1.7969e+00, -1.3828e+00, 1.2656e+00, ..., -4.7812e+00,\n", + " 5.2812e+00, -3.1250e+00],\n", + " [ 1.1406e+00, -1.6992e-01, -5.8594e-01, ..., -3.4688e+00,\n", + " 4.2188e-01, 7.3828e-01],\n", + " [ 2.0312e+00, -1.0703e+00, 1.0547e+00, ..., -2.8594e+00,\n", + " -4.5898e-01, -2.9844e+00]],\n", + "\n", + " [[ 1.7676e-01, 4.6387e-02, -2.8125e-01, ..., 1.7188e+00,\n", + " 1.1523e-01, 4.8047e-01],\n", + " [ 1.9727e-01, 1.9336e-01, -3.5352e-01, ..., 1.6719e+00,\n", + " 3.9062e-01, 4.2578e-01],\n", + " [ 1.2695e-01, 2.0410e-01, -2.5391e-01, ..., 1.8438e+00,\n", + " 2.4707e-01, 5.1562e-01],\n", + " ...,\n", + " [ 2.2500e+00, -9.2285e-02, 1.8438e+00, ..., 1.1182e-01,\n", + " -1.2031e+00, 1.7969e+00],\n", + " [-3.9648e-01, 2.4531e+00, 2.2969e+00, ..., -6.1328e-01,\n", + " 5.1953e-01, 6.9922e-01],\n", + " [-8.6719e-01, 7.1875e-01, 1.0312e+00, ..., 1.0645e-01,\n", + " -6.3672e-01, -3.9062e-03]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[ 1.2158e-01, 2.9785e-02, -2.8809e-02, ..., -2.2095e-02,\n", + " -2.9541e-02, 4.4434e-02],\n", + " [ 1.1865e-01, 1.0400e-01, 5.4016e-03, ..., 2.5635e-02,\n", + " -7.9590e-02, 6.7871e-02],\n", + " [ 6.8848e-02, 1.0254e-01, -3.9307e-02, ..., -1.1414e-02,\n", + " -1.6479e-02, 3.0884e-02],\n", + " ...,\n", + " [ 5.7031e-01, -9.2188e-01, -1.3965e-01, ..., -4.2188e-01,\n", + " -1.0596e-01, 8.9844e-01],\n", + " [ 3.3203e-01, -3.2812e-01, -6.4062e-01, ..., 3.2617e-01,\n", + " 2.7734e-01, -5.3125e-01],\n", + " [ 2.8516e-01, 3.8867e-01, -3.0078e-01, ..., 5.8984e-01,\n", + " -8.9844e-01, 2.0410e-01]],\n", + "\n", + " [[-5.2734e-02, -2.5513e-02, -1.1719e-01, ..., -5.5176e-02,\n", + " 2.0874e-02, -4.9561e-02],\n", + " [-1.0742e-01, 9.4727e-02, -4.3701e-02, ..., -5.3467e-02,\n", + " 3.7354e-02, -3.3691e-02],\n", + " [-8.2520e-02, 6.7871e-02, -5.0049e-02, ..., -1.9653e-02,\n", + " -6.7871e-02, 5.2246e-02],\n", + " ...,\n", + " [ 5.8350e-02, -5.6250e-01, -1.0986e-01, ..., 2.4414e-01,\n", + " -8.5938e-01, 3.7109e-01],\n", + " [ 4.4727e-01, -5.0781e-01, -1.2422e+00, ..., 3.1250e-01,\n", + " 2.3438e-01, 9.6875e-01],\n", + " [ 3.8867e-01, -3.8477e-01, 8.0078e-01, ..., 5.0391e-01,\n", + " -5.0391e-01, 7.8516e-01]],\n", + "\n", + " [[-9.3750e-02, 4.3213e-02, -3.9307e-02, ..., 1.5625e-01,\n", + " -1.9775e-02, -3.6621e-02],\n", + " [-3.2227e-02, 1.0303e-01, -4.1748e-02, ..., 5.1514e-02,\n", + " -4.1016e-02, 1.0193e-02],\n", + " [-6.0791e-02, 1.1523e-01, -7.8125e-02, ..., 1.3672e-01,\n", + " 9.0332e-03, -1.1063e-03],\n", + " ...,\n", + " [-3.9844e-01, 2.7148e-01, 4.0234e-01, ..., 1.3203e+00,\n", + " 5.8984e-01, 2.7344e-01],\n", + " [-7.0703e-01, 9.6875e-01, 3.1055e-01, ..., 5.3467e-02,\n", + " -5.6458e-03, -3.7109e-02],\n", + " [ 1.5527e-01, 3.4766e-01, -1.5820e-01, ..., 1.8164e-01,\n", + " -5.8984e-01, -8.9355e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-0.1167, 0.0728, 0.1079, ..., 0.1865, -0.8672, 0.3223],\n", + " [-0.0498, 0.1201, 0.2656, ..., 0.3242, -0.7930, 0.4297],\n", + " [ 0.0464, -0.0157, 0.2207, ..., 0.4805, -0.7422, 0.3027],\n", + " ...,\n", + " [-0.7969, 1.6484, -0.0051, ..., -1.1875, -0.9258, 1.9688],\n", + " [ 1.4688, 1.4062, -1.0156, ..., 0.9258, -2.2969, 1.1250],\n", + " [-0.5430, -0.9023, -0.5547, ..., -0.7812, -0.4199, 0.5781]],\n", + "\n", + " [[ 0.0527, -0.0864, -0.0840, ..., 0.5938, -0.0496, 1.2344],\n", + " [ 0.0728, -0.0420, -0.1270, ..., 0.6406, -0.0194, 1.1953],\n", + " [ 0.0957, 0.1006, -0.0425, ..., 0.7461, 0.1621, 1.2188],\n", + " ...,\n", + " [ 1.0000, 0.9297, 1.4219, ..., -0.4805, -1.9609, 0.8711],\n", + " [ 1.4375, 0.4668, 1.1016, ..., 1.2422, -0.6797, 1.0859],\n", + " [ 1.0859, 0.6172, 1.9219, ..., 1.0469, -1.4219, 2.5000]],\n", + "\n", + " [[ 0.0095, 0.0217, -0.0439, ..., -1.4453, -2.0312, -1.2422],\n", + " [ 0.0071, 0.0464, -0.1445, ..., -1.4531, -2.0469, -1.2422],\n", + " [-0.0898, 0.0654, -0.1406, ..., -1.3828, -2.0000, -1.3359],\n", + " ...,\n", + " [-0.5078, -0.7852, 1.9219, ..., -2.6406, -1.8359, -1.1094],\n", + " [-2.3125, 2.6406, 2.5625, ..., -2.9844, -2.3438, -2.1875],\n", + " [-1.8438, 0.9180, 0.3281, ..., -0.9609, -1.0625, -3.0781]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[[-5.9082e-02, 6.2500e-02, -3.3188e-04, ..., 2.8687e-02,\n", + " 1.8845e-03, -7.4219e-02],\n", + " [-6.2500e-02, 1.5039e-01, 2.9663e-02, ..., 1.1841e-02,\n", + " -3.5400e-02, -1.0352e-01],\n", + " [-6.7871e-02, 4.9072e-02, -1.1841e-02, ..., -8.3496e-02,\n", + " 1.2256e-01, -6.4941e-02],\n", + " ...,\n", + " [ 4.9609e-01, 7.5781e-01, 1.6699e-01, ..., -9.8145e-02,\n", + " 5.7031e-01, 1.6504e-01],\n", + " [-1.2598e-01, 5.8594e-01, -8.5547e-01, ..., -3.3203e-02,\n", + " 1.9336e-01, -2.7148e-01],\n", + " [-2.8516e-01, -2.6367e-01, -6.9531e-01, ..., 8.9453e-01,\n", + " 6.3281e-01, 1.0625e+00]],\n", + "\n", + " [[ 1.5747e-02, -2.4170e-02, 3.4424e-02, ..., -6.3477e-02,\n", + " 8.5449e-02, 4.6875e-01],\n", + " [ 3.5889e-02, 3.1982e-02, 1.4496e-03, ..., -1.6406e-01,\n", + " 2.8564e-02, 4.4922e-01],\n", + " [ 3.5645e-02, -5.7129e-02, -6.1279e-02, ..., -4.2725e-02,\n", + " 3.3691e-02, 4.5508e-01],\n", + " ...,\n", + " [-4.1016e-01, 5.0781e-01, -7.1484e-01, ..., -1.6602e-01,\n", + " -1.2266e+00, -4.7852e-01],\n", + " [ 7.5391e-01, 4.9414e-01, 6.7188e-01, ..., 3.9258e-01,\n", + " -3.0078e-01, -2.1094e+00],\n", + " [-1.1768e-01, -3.5400e-02, 7.2937e-03, ..., 1.4844e-01,\n", + " -4.8242e-01, -6.0938e-01]],\n", + "\n", + " [[-8.6426e-02, -3.3691e-02, 8.2031e-02, ..., -3.5645e-02,\n", + " 5.8984e-01, 7.8125e-02],\n", + " [-5.0781e-02, -3.8574e-02, 8.6426e-02, ..., -1.0254e-01,\n", + " 5.3906e-01, -1.2512e-02],\n", + " [-7.8613e-02, 2.9175e-02, 1.4453e-01, ..., -4.1016e-02,\n", + " 5.8984e-01, 3.3447e-02],\n", + " ...,\n", + " [ 1.5000e+00, -1.1172e+00, -3.2617e-01, ..., 9.6484e-01,\n", + " -1.1172e+00, -7.2656e-01],\n", + " [ 1.1250e+00, -6.9824e-02, 1.1016e+00, ..., 6.3672e-01,\n", + " -6.6406e-01, -1.1016e+00],\n", + " [-2.0508e-01, -1.0312e+00, 3.0273e-01, ..., 5.6396e-02,\n", + " -1.2031e+00, -3.3398e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 0.0649, 0.0364, 0.0082, ..., 0.7383, -2.2969, 0.3027],\n", + " [ 0.1211, -0.3398, 0.1445, ..., 0.7578, -2.3438, 0.0688],\n", + " [ 0.0347, -0.2344, 0.0747, ..., 0.6758, -2.2812, 0.1553],\n", + " ...,\n", + " [ 0.8750, -0.0620, -0.5391, ..., 0.9922, -1.4219, -1.5547],\n", + " [-1.3438, -2.3750, -2.2969, ..., 0.4121, -1.1094, -2.1562],\n", + " [-0.4766, -0.7344, -0.0781, ..., 0.2021, 0.6680, 1.3594]],\n", + "\n", + " [[ 0.0957, 0.0723, 0.0260, ..., 0.8555, -2.3125, -1.6328],\n", + " [ 0.1797, -0.0933, -0.0156, ..., 0.6875, -2.1875, -1.7109],\n", + " [-0.0659, -0.2139, -0.0261, ..., 0.8008, -2.3125, -1.6797],\n", + " ...,\n", + " [ 1.3125, -0.3691, 0.3574, ..., 0.2695, -4.2188, -3.2188],\n", + " [-0.4062, -1.0859, 0.0742, ..., -1.0234, -3.6875, -2.4062],\n", + " [-0.4375, 0.5859, 1.1562, ..., 4.2500, -3.6250, -3.1562]],\n", + "\n", + " [[ 0.1455, -0.1187, 0.1455, ..., 0.5508, -1.1797, 2.0938],\n", + " [ 0.2988, 0.0669, 0.2734, ..., 1.2031, -0.5352, 2.9219],\n", + " [ 0.0000, 0.0737, 0.2578, ..., 0.5117, -0.6523, 2.5312],\n", + " ...,\n", + " [ 1.3203, -1.3984, -0.6602, ..., -1.1484, 7.8750, -7.9688],\n", + " [-1.1953, 0.2812, -1.7578, ..., 1.1250, 8.0000, -2.6094],\n", + " [-0.9570, 0.8008, -0.4043, ..., 2.7188, 5.1250, 0.5586]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[[ 5.4199e-02, 2.7588e-02, -1.8848e-01, ..., -1.5137e-01,\n", + " -1.5488e-03, -5.3467e-02],\n", + " [-1.1865e-01, -1.8652e-01, -8.0566e-02, ..., -1.4453e-01,\n", + " -6.3477e-02, 4.1504e-02],\n", + " [-2.0264e-02, 3.0518e-04, -4.3457e-02, ..., -9.9121e-02,\n", + " -3.8818e-02, 2.0752e-02],\n", + " ...,\n", + " [-3.0469e+00, -1.1016e+00, -4.6484e-01, ..., 2.6562e-01,\n", + " 4.9219e-01, 6.4844e-01],\n", + " [-2.7031e+00, -2.4062e+00, 9.2578e-01, ..., 7.0312e-01,\n", + " 1.2422e+00, -4.2188e-01],\n", + " [ 5.1562e-01, -1.6602e-01, 1.3184e-01, ..., 7.8906e-01,\n", + " -1.1328e+00, -1.7285e-01]],\n", + "\n", + " [[ 8.7891e-03, -2.0508e-02, 6.0059e-02, ..., -1.7285e-01,\n", + " 6.6895e-02, -1.1328e-01],\n", + " [-2.5391e-02, 7.3242e-02, 7.6660e-02, ..., -1.5430e-01,\n", + " 3.0664e-01, -4.1260e-02],\n", + " [ 1.3672e-01, -2.1240e-02, -1.4258e-01, ..., -9.4238e-02,\n", + " 3.2715e-02, -1.0547e-01],\n", + " ...,\n", + " [ 6.9141e-01, -4.6484e-01, -3.8818e-02, ..., 1.0107e-01,\n", + " 2.1562e+00, 7.5391e-01],\n", + " [ 2.3828e-01, -3.8477e-01, 7.7734e-01, ..., 9.0234e-01,\n", + " -1.3516e+00, -1.0469e+00],\n", + " [ 9.1309e-02, -3.4375e-01, 6.7188e-01, ..., 1.9062e+00,\n", + " 6.7383e-02, 1.4648e-01]],\n", + "\n", + " [[ 2.0599e-03, -9.5215e-02, -2.5635e-02, ..., 9.1797e-02,\n", + " -6.3965e-02, 4.9316e-02],\n", + " [-8.6060e-03, -1.8750e-01, -1.0132e-02, ..., 9.2285e-02,\n", + " -5.9570e-02, 4.1504e-02],\n", + " [ 7.3853e-03, -1.3770e-01, 2.5879e-02, ..., 9.3750e-02,\n", + " -1.3477e-01, 7.1777e-02],\n", + " ...,\n", + " [-1.0156e+00, 4.6680e-01, 1.7871e-01, ..., -4.1406e-01,\n", + " -1.1094e+00, -2.2949e-01],\n", + " [-3.4180e-01, 2.8320e-02, 2.8711e-01, ..., 3.2422e-01,\n", + " -2.8516e-01, -6.8848e-02],\n", + " [-2.5391e-01, 1.8158e-03, -3.4180e-02, ..., 3.5156e-01,\n", + " 6.2109e-01, -3.8672e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-9.3384e-03, 9.6191e-02, -4.7363e-02, ..., -4.7070e-01,\n", + " 2.6875e+00, 7.1094e-01],\n", + " [ 1.7773e-01, 1.3477e-01, -5.4932e-02, ..., -5.1953e-01,\n", + " 2.9062e+00, 8.3984e-01],\n", + " [ 1.5430e-01, 3.0762e-02, 2.0020e-02, ..., -7.3438e-01,\n", + " 2.8281e+00, 6.0938e-01],\n", + " ...,\n", + " [ 1.0781e+00, 1.3438e+00, 1.9922e+00, ..., -4.3555e-01,\n", + " -5.3750e+00, 1.0312e+00],\n", + " [ 1.1250e+00, 1.5312e+00, 1.5703e+00, ..., -3.7656e+00,\n", + " -6.7188e+00, 3.9375e+00],\n", + " [ 7.7344e-01, 9.5312e-01, 7.0312e-01, ..., 3.4375e-01,\n", + " 6.6016e-01, 3.3984e-01]],\n", + "\n", + " [[-3.0884e-02, -1.0303e-01, -1.7188e-01, ..., -1.8281e+00,\n", + " 3.2031e-01, 4.0820e-01],\n", + " [ 1.0840e-01, -7.2266e-02, -6.2500e-02, ..., -2.0625e+00,\n", + " 1.2988e-01, 3.6719e-01],\n", + " [ 1.5430e-01, 7.6172e-02, -7.1777e-02, ..., -2.2656e+00,\n", + " 1.5234e-01, 6.1719e-01],\n", + " ...,\n", + " [ 1.5781e+00, -1.0547e+00, 4.6250e+00, ..., -2.0156e+00,\n", + " 6.0547e-01, -3.9673e-03],\n", + " [ 4.8750e+00, -7.7344e-01, 2.5312e+00, ..., -2.7812e+00,\n", + " -8.9722e-03, -2.7734e-01],\n", + " [ 3.5645e-02, 9.8828e-01, 3.1641e-01, ..., -2.2188e+00,\n", + " 1.1172e+00, 1.3828e+00]],\n", + "\n", + " [[-2.1118e-02, 1.2988e-01, -4.0039e-02, ..., 1.7773e-01,\n", + " 1.1572e-01, 3.8477e-01],\n", + " [-3.1128e-02, 1.2207e-02, -1.9238e-01, ..., 1.3672e-01,\n", + " -1.1035e-01, 5.1562e-01],\n", + " [-1.5869e-03, -3.3203e-02, -1.2305e-01, ..., 5.1025e-02,\n", + " -1.5820e-01, 3.9648e-01],\n", + " ...,\n", + " [ 4.4141e-01, -1.6875e+00, 1.2969e+00, ..., -9.8828e-01,\n", + " -3.1250e-01, -5.3516e-01],\n", + " [ 1.1641e+00, -9.3750e-01, 1.6406e-01, ..., -1.8516e+00,\n", + " -6.5234e-01, 1.0938e+00],\n", + " [ 1.5234e-01, 5.9570e-02, 5.0000e-01, ..., 6.6406e-02,\n", + " 1.5469e+00, -1.8750e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-6.4941e-02, 6.4453e-02, -3.1250e-01, ..., -8.9844e-02,\n", + " 2.5195e-01, 1.6602e-02],\n", + " [-2.1387e-01, 8.6426e-02, -3.8281e-01, ..., -9.0332e-02,\n", + " 2.2070e-01, 1.1658e-02],\n", + " [-1.5625e-01, 4.7607e-02, -3.7305e-01, ..., -3.3203e-02,\n", + " 1.0156e-01, -8.8867e-02],\n", + " ...,\n", + " [-5.9375e-01, -3.4570e-01, 2.9844e+00, ..., 1.2500e+00,\n", + " -5.8203e-01, 1.6250e+00],\n", + " [-5.1562e-01, 1.1484e+00, 3.1406e+00, ..., -3.8086e-01,\n", + " -3.8672e-01, 9.6484e-01],\n", + " [ 8.6719e-01, 1.4844e-01, 1.3984e+00, ..., -1.6992e-01,\n", + " -1.1963e-01, 1.2734e+00]],\n", + "\n", + " [[ 2.6245e-02, 4.1992e-02, 7.5684e-02, ..., -9.9609e-02,\n", + " -2.2949e-02, -3.1128e-02],\n", + " [ 5.9326e-02, 4.0527e-02, 1.0352e-01, ..., -1.6895e-01,\n", + " -1.4746e-01, -1.5137e-01],\n", + " [ 8.6914e-02, -5.0049e-02, -6.2500e-02, ..., -9.1309e-02,\n", + " 2.3651e-03, -1.5137e-01],\n", + " ...,\n", + " [ 3.9648e-01, -5.7812e-01, -9.8047e-01, ..., -2.5977e-01,\n", + " 6.5234e-01, -5.1953e-01],\n", + " [ 1.1094e+00, -7.5000e-01, 6.4453e-02, ..., 3.4180e-01,\n", + " -8.1250e-01, 1.6504e-01],\n", + " [ 3.6133e-01, -5.6641e-01, -4.4336e-01, ..., -2.9883e-01,\n", + " 2.5781e-01, -5.1562e-01]],\n", + "\n", + " [[-9.5215e-03, -5.6885e-02, -9.1797e-02, ..., 6.0547e-02,\n", + " -9.1797e-02, -7.4463e-03],\n", + " [ 5.0537e-02, -3.0151e-02, -1.4648e-02, ..., 9.1797e-02,\n", + " -9.6191e-02, 3.6316e-03],\n", + " [ 1.3580e-03, -8.1543e-02, -5.6641e-02, ..., 2.0703e-01,\n", + " -8.5449e-02, -1.1328e-01],\n", + " ...,\n", + " [ 3.6523e-01, 7.3438e-01, 2.8906e-01, ..., -1.0312e+00,\n", + " 5.3516e-01, -2.3926e-01],\n", + " [ 3.8867e-01, 3.7109e-01, -1.6875e+00, ..., -1.0078e+00,\n", + " 8.8867e-02, -7.4609e-01],\n", + " [-3.8086e-02, 5.7812e-01, -5.0000e-01, ..., 8.7109e-01,\n", + " -5.3125e-01, 1.6211e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-0.2949, -0.4551, -0.1104, ..., 0.9375, -0.0190, 1.1094],\n", + " [ 0.1230, -0.5156, -0.3926, ..., 1.0156, -0.3359, 1.0078],\n", + " [ 0.5234, -0.1182, -0.3438, ..., 1.0312, -0.0913, 1.1719],\n", + " ...,\n", + " [-0.2930, -2.6406, -0.0522, ..., 2.4375, -0.9961, 2.6250],\n", + " [ 2.3438, -2.0781, 0.6797, ..., 2.7656, -0.8828, 1.2578],\n", + " [-0.3594, -0.0166, 0.6680, ..., 1.7031, -0.6992, 2.0312]],\n", + "\n", + " [[ 0.2139, -0.0610, 0.0205, ..., 0.8125, 0.4883, 0.0625],\n", + " [ 0.0259, 0.0747, 0.1289, ..., 0.9219, 0.3613, 0.0801],\n", + " [-0.2100, 0.1260, 0.0820, ..., 0.7812, 0.4434, 0.3652],\n", + " ...,\n", + " [ 2.0781, -2.2812, -1.8672, ..., 1.2188, -1.9531, -0.3594],\n", + " [-0.0234, -0.2578, -2.0625, ..., 1.0781, 1.6406, 0.3301],\n", + " [-1.1016, 0.2266, -1.9062, ..., 1.3594, 0.6484, -0.3125]],\n", + "\n", + " [[ 0.2207, -0.2266, -0.1426, ..., -0.6016, 0.1934, -1.2031],\n", + " [ 0.2637, -0.1299, -0.2207, ..., -0.3965, 0.1758, -1.1172],\n", + " [ 0.0610, 0.0413, -0.2930, ..., -0.3203, 0.2246, -1.2812],\n", + " ...,\n", + " [ 3.0000, -3.2344, -0.5312, ..., -2.0625, 1.1875, -0.2676],\n", + " [-1.3047, -2.9219, 0.1055, ..., -1.3047, 1.7109, -2.7656],\n", + " [-0.5898, -0.4258, 0.1699, ..., -2.3750, 0.5156, -2.5781]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[[-6.1719e-01, -1.3977e-02, 2.4658e-02, ..., 9.8633e-02,\n", + " -6.7871e-02, -4.1992e-02],\n", + " [-5.1172e-01, 8.0078e-02, 4.2969e-02, ..., 4.1016e-02,\n", + " -7.7637e-02, -1.4648e-01],\n", + " [-6.2109e-01, 1.3489e-02, 2.4658e-02, ..., 1.0840e-01,\n", + " -6.6895e-02, -4.1199e-03],\n", + " ...,\n", + " [ 1.2188e+00, -2.1191e-01, 1.9434e-01, ..., 1.9336e-01,\n", + " -4.6680e-01, -2.4062e+00],\n", + " [ 6.2109e-01, 6.8359e-02, -3.9307e-02, ..., -6.2988e-02,\n", + " -4.4922e-01, -2.0938e+00],\n", + " [-2.1094e-01, 6.8359e-01, -1.6797e-01, ..., 5.9766e-01,\n", + " -1.0596e-01, 5.6250e-01]],\n", + "\n", + " [[ 8.3008e-02, 1.8945e-01, 3.7384e-04, ..., -9.8267e-03,\n", + " 7.6172e-02, -1.1133e-01],\n", + " [ 1.6504e-01, 1.0596e-01, -7.6172e-02, ..., 5.5908e-02,\n", + " 1.3672e-01, -1.2793e-01],\n", + " [ 1.8387e-03, 6.9824e-02, 1.6602e-02, ..., -6.6406e-02,\n", + " 1.0791e-01, -1.2402e-01],\n", + " ...,\n", + " [ 1.0391e+00, -3.1250e-01, -1.0498e-01, ..., -2.5625e+00,\n", + " 3.5156e-02, -1.8652e-01],\n", + " [ 3.7500e-01, 1.9141e-01, 4.4922e-01, ..., -8.8281e-01,\n", + " 2.3682e-02, 3.5742e-01],\n", + " [ 3.3936e-02, 2.7924e-03, -3.1250e-01, ..., 8.8672e-01,\n", + " 4.1260e-02, 7.9102e-02]],\n", + "\n", + " [[-7.2266e-02, -5.1758e-02, -3.5400e-02, ..., 1.1963e-01,\n", + " 8.5938e-02, -5.6641e-02],\n", + " [-6.0059e-02, -1.0596e-01, 3.2227e-02, ..., 1.8164e-01,\n", + " 6.4453e-02, -5.4688e-02],\n", + " [-9.3750e-02, -4.3213e-02, -1.0059e-01, ..., 5.0049e-02,\n", + " -6.8359e-02, -1.4160e-01],\n", + " ...,\n", + " [ 3.3789e-01, 2.7539e-01, -3.1250e-01, ..., -1.5234e+00,\n", + " 1.0312e+00, 1.5234e+00],\n", + " [-2.7539e-01, 1.8457e-01, -1.5078e+00, ..., -7.7637e-02,\n", + " 5.3125e-01, 2.0117e-01],\n", + " [ 2.4219e-01, 2.8516e-01, 3.8086e-01, ..., -4.4922e-01,\n", + " -1.4954e-02, -1.9836e-03]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 0.2031, 2.6094, 1.0859, ..., 0.9727, -1.7812, -1.4297],\n", + " [-2.6250, 2.8906, 0.0000, ..., 1.0234, -1.3906, -1.3906],\n", + " [-2.2500, 0.3203, -0.6016, ..., 1.7969, -2.4531, -1.5547],\n", + " ...,\n", + " [-3.0938, 2.8594, -4.0625, ..., -0.2432, -1.3672, 0.5469],\n", + " [-2.6250, 0.4922, -1.6250, ..., 1.1094, -0.8359, -2.0312],\n", + " [-0.0938, 0.7109, -0.0859, ..., 0.4277, -0.7734, -1.5547]],\n", + "\n", + " [[-0.3438, -0.3555, 0.5820, ..., 1.1328, -1.5312, 2.1719],\n", + " [-2.1250, -0.9297, -0.2422, ..., 1.0000, -1.3906, 2.0469],\n", + " [-0.8828, -0.6484, -0.2266, ..., 1.0000, -0.8789, 1.0156],\n", + " ...,\n", + " [-6.4688, 1.1641, 1.4922, ..., 2.9531, -1.1875, 2.0938],\n", + " [-4.3438, -0.6797, 3.0469, ..., 2.4531, -1.0625, 1.7266],\n", + " [-1.5547, -2.7188, 2.7656, ..., 2.5625, -1.4922, 3.5000]],\n", + "\n", + " [[-1.2109, -0.7812, 0.3574, ..., 2.3750, 1.6953, 0.5117],\n", + " [-1.6406, -2.0781, 0.0586, ..., 2.3438, 1.1016, 0.5234],\n", + " [ 0.2734, -0.7930, -0.4082, ..., 2.2812, 2.2031, 0.5430],\n", + " ...,\n", + " [-2.4688, -2.1406, -1.6953, ..., 1.1562, 1.4766, 1.8672],\n", + " [ 2.0938, -0.3711, -0.1562, ..., 1.1875, 1.7031, 2.2812],\n", + " [ 1.6016, 0.6641, 0.4199, ..., 1.2812, 2.0625, 1.4531]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[[ 0.0874, 0.2656, 0.1973, ..., -0.2734, -0.2354, -0.2119],\n", + " [ 0.0554, 0.1885, 0.2480, ..., -0.8477, -0.3242, -0.7695],\n", + " [-0.0601, 0.0820, 0.4238, ..., -0.2412, -0.1709, -0.1426],\n", + " ...,\n", + " [ 0.7891, 0.3184, -0.6211, ..., -0.0527, 0.1963, 0.2852],\n", + " [-0.1504, -0.3281, 0.2080, ..., 0.0201, 0.8125, 0.2988],\n", + " [-0.4199, -0.0168, 0.0466, ..., 0.3008, 0.3672, -0.2383]],\n", + "\n", + " [[ 0.1650, -0.3770, 0.2773, ..., 0.2520, -0.3555, 0.1104],\n", + " [ 0.1738, -0.1367, 0.3281, ..., 0.0815, 0.2041, -0.2197],\n", + " [ 0.0559, -0.2178, 0.2637, ..., 0.1216, -0.1484, -0.0850],\n", + " ...,\n", + " [ 0.6914, 0.3730, -0.0215, ..., 0.2656, -0.3496, 0.0231],\n", + " [ 0.8398, -0.0184, -0.0649, ..., 0.8555, -0.3477, 0.1387],\n", + " [-0.0310, 0.1069, 0.2217, ..., -0.0640, -0.2275, 0.1875]],\n", + "\n", + " [[ 0.1729, -0.1797, 0.3418, ..., 0.2988, -0.5898, 0.1279],\n", + " [ 0.1602, -0.2891, 0.8555, ..., 0.2129, -0.6992, -0.6250],\n", + " [ 0.2197, 0.1670, 0.5469, ..., 0.0286, -0.5312, -0.0732],\n", + " ...,\n", + " [-0.2480, 0.4629, 0.0337, ..., 1.5625, 0.1973, -0.7773],\n", + " [-0.2988, -0.0571, -0.6094, ..., 0.4805, 0.3438, -0.0449],\n", + " [-0.1982, -0.2559, -0.3320, ..., -0.1006, -0.1777, -0.2217]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=)), (tensor([[[[-4.9062, -3.5938, 4.4375, ..., -5.8750, 4.0938, -0.3086],\n", + " [ 1.8750, -3.0000, 3.0312, ..., -5.2188, 3.6250, -0.2832],\n", + " [ 5.8438, -1.6719, 1.8281, ..., -5.0625, 2.7656, -0.3320],\n", + " ...,\n", + " [ 2.6562, -1.7344, -3.0469, ..., -4.8750, 1.2969, -1.3359],\n", + " [ 5.1562, -1.5234, -2.1094, ..., -5.8438, 2.0312, -1.9453],\n", + " [ 3.4062, 0.1562, -1.8125, ..., -4.6250, 2.5000, -0.6172]],\n", + "\n", + " [[-0.9961, 2.2188, 1.8672, ..., -3.1250, -5.4688, 1.9766],\n", + " [-1.8750, 1.0078, 1.0703, ..., -1.7422, -5.0000, 1.1719],\n", + " [-0.2422, 0.8164, -0.1914, ..., -2.3750, -3.9219, 1.6094],\n", + " ...,\n", + " [-3.6562, -0.9727, -4.7500, ..., -3.8750, -4.7812, 2.4844],\n", + " [ 1.3281, -0.3770, -2.2656, ..., -3.2031, -4.1562, 3.7969],\n", + " [ 0.2969, -0.2949, 0.4375, ..., -3.2188, -4.1250, 2.7188]],\n", + "\n", + " [[-3.5156, 2.0000, -1.1562, ..., 1.8906, 8.2500, -0.4629],\n", + " [-2.3750, 0.6484, -0.0898, ..., 0.6641, 7.9688, -0.0356],\n", + " [ 0.9648, 0.0215, 0.8164, ..., 4.6875, 5.1562, -1.9766],\n", + " ...,\n", + " [-1.3750, 0.4219, 1.2266, ..., 4.1250, 5.8125, -2.6875],\n", + " [ 2.2500, 1.3359, -0.2617, ..., 3.9844, 6.2188, -2.5312],\n", + " [ 0.9648, -0.6172, -0.4863, ..., 3.2500, 5.8438, -2.0938]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[[ 2.2949e-01, 3.2715e-02, 1.1621e-01, ..., 8.2031e-01,\n", + " 1.4258e-01, -4.9219e-01],\n", + " [ 8.1055e-02, -2.0801e-01, 3.1055e-01, ..., 7.2266e-01,\n", + " -1.8945e-01, -3.9844e-01],\n", + " [ 3.2422e-01, 2.5977e-01, 5.0391e-01, ..., 3.2227e-01,\n", + " -1.6113e-01, -4.8584e-02],\n", + " ...,\n", + " [-3.5938e-01, -3.9551e-02, 4.1602e-01, ..., -7.0703e-01,\n", + " -2.7734e-01, 3.1836e-01],\n", + " [-1.8066e-01, 1.0107e-01, 6.7188e-01, ..., -8.3594e-01,\n", + " -3.6719e-01, 6.9141e-01],\n", + " [ 3.5889e-02, -1.4746e-01, -3.8672e-01, ..., 1.5234e-01,\n", + " 1.0938e-01, -3.7537e-03]],\n", + "\n", + " [[-2.6953e-01, -5.6396e-02, 7.7344e-01, ..., 1.9238e-01,\n", + " -5.3516e-01, -3.5547e-01],\n", + " [-3.3398e-01, -4.7461e-01, 4.4922e-01, ..., -1.8066e-01,\n", + " 1.4844e-01, -3.8281e-01],\n", + " [ 3.9453e-01, -6.1719e-01, -1.3770e-01, ..., -2.0020e-01,\n", + " -4.0820e-01, -2.7930e-01],\n", + " ...,\n", + " [ 2.2070e-01, 1.4305e-04, 1.1719e+00, ..., 1.2656e+00,\n", + " -4.5898e-01, 1.4648e-01],\n", + " [ 4.6484e-01, -3.6328e-01, -5.2185e-03, ..., -8.4375e-01,\n", + " -2.8320e-01, 1.6113e-01],\n", + " [-4.2188e-01, -8.8379e-02, 7.1484e-01, ..., -1.6895e-01,\n", + " 4.6692e-03, 2.6367e-01]],\n", + "\n", + " [[-4.7656e-01, -4.4531e-01, -4.6484e-01, ..., -8.3984e-01,\n", + " -9.9121e-02, -1.2344e+00],\n", + " [-1.9336e-01, -5.3906e-01, -7.3047e-01, ..., -5.5469e-01,\n", + " 3.1055e-01, -1.0469e+00],\n", + " [-3.2812e-01, 1.9688e+00, 1.3359e+00, ..., 1.1328e+00,\n", + " -1.8672e+00, -5.5078e-01],\n", + " ...,\n", + " [ 2.9102e-01, -4.1016e-01, 5.8203e-01, ..., -6.7969e-01,\n", + " 4.9414e-01, -4.7070e-01],\n", + " [ 3.4180e-02, -5.4297e-01, 1.2146e-02, ..., -3.5156e-01,\n", + " 1.4551e-01, -5.6152e-02],\n", + " [-3.9062e-02, -1.6968e-02, 4.9609e-01, ..., -3.7500e-01,\n", + " 1.4258e-01, 2.2363e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 3.8574e-02, -7.7209e-03, 1.1536e-02, ..., -1.1719e-01,\n", + " 1.9844e+00, 7.0312e-02],\n", + " [ 1.2085e-02, -1.8921e-02, 3.3264e-03, ..., -1.1572e-01,\n", + " 1.9844e+00, 7.5684e-02],\n", + " [ 3.0273e-01, 3.0469e-01, -5.8203e-01, ..., 6.6797e-01,\n", + " 6.1719e-01, 6.2500e-01],\n", + " ...,\n", + " [ 8.0469e-01, 2.1250e+00, -1.4531e+00, ..., 5.5859e-01,\n", + " -2.6094e+00, 8.5547e-01],\n", + " [ 1.2969e+00, 1.2188e+00, -3.3594e-01, ..., 5.3906e-01,\n", + " -2.0156e+00, 1.0156e+00],\n", + " [-3.0664e-01, 2.9492e-01, 5.0000e-01, ..., 4.3125e+00,\n", + " 5.0781e-01, -3.9648e-01]],\n", + "\n", + " [[-1.8555e-02, 5.2795e-03, -6.4087e-03, ..., -1.0938e+00,\n", + " -2.5312e+00, 1.1719e+00],\n", + " [ 1.8311e-04, -2.0142e-02, 2.0599e-03, ..., -1.0938e+00,\n", + " -2.5312e+00, 1.1797e+00],\n", + " [-1.1572e-01, -1.7969e-01, -5.6250e-01, ..., -1.2188e+00,\n", + " -2.2031e+00, 1.1484e+00],\n", + " ...,\n", + " [ 8.6328e-01, -5.6641e-01, -1.8047e+00, ..., -3.3594e-01,\n", + " -2.2500e+00, 8.2812e-01],\n", + " [-2.5469e+00, 3.2031e-01, -1.9141e-01, ..., 1.0986e-01,\n", + " -1.0703e+00, 1.6172e+00],\n", + " [-1.1875e+00, 1.2305e-01, 1.7344e+00, ..., -8.9355e-02,\n", + " -6.0547e-01, 1.1719e+00]],\n", + "\n", + " [[ 6.4087e-03, 4.2725e-03, 2.8839e-03, ..., -3.8672e-01,\n", + " -9.6191e-02, 5.5859e-01],\n", + " [ 1.6357e-02, 1.1292e-02, 5.5542e-03, ..., -3.8281e-01,\n", + " -9.6191e-02, 5.5469e-01],\n", + " [-2.5781e-01, 3.0273e-01, -3.2227e-01, ..., -5.5469e-01,\n", + " -1.5918e-01, 4.6875e-01],\n", + " ...,\n", + " [ 1.2969e+00, 9.2969e-01, -1.0781e+00, ..., -1.9453e+00,\n", + " -9.0234e-01, 6.0156e-01],\n", + " [-6.8750e-01, 2.0625e+00, 2.3438e-02, ..., -5.2734e-01,\n", + " 2.4023e-01, 3.4180e-01],\n", + " [-9.7656e-01, 4.2188e-01, 8.8281e-01, ..., -1.6504e-01,\n", + " -6.9922e-01, 1.2578e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[ 4.1809e-03, -4.1016e-02, -4.3392e-05, ..., 2.0996e-02,\n", + " -7.3730e-02, -5.6458e-03],\n", + " [ 1.3828e-04, -4.1260e-02, 1.0834e-03, ..., 2.1729e-02,\n", + " -7.3730e-02, -8.3618e-03],\n", + " [ 1.1353e-02, -1.6699e-01, -8.0566e-02, ..., -1.5503e-02,\n", + " -8.4473e-02, -9.2285e-02],\n", + " ...,\n", + " [-4.9609e-01, -7.0801e-02, -1.6992e-01, ..., 7.6172e-01,\n", + " 2.4512e-01, 2.6953e-01],\n", + " [ 1.3203e+00, 9.3750e-02, -5.7422e-01, ..., 3.7305e-01,\n", + " 3.5352e-01, -5.1953e-01],\n", + " [ 2.1484e-01, 1.9531e-01, 1.6211e-01, ..., -3.6523e-01,\n", + " -7.7734e-01, 3.1641e-01]],\n", + "\n", + " [[-1.7578e-02, -7.8125e-03, -3.0762e-02, ..., -1.5137e-02,\n", + " 1.2024e-02, -1.4725e-03],\n", + " [-1.6968e-02, -7.6904e-03, -2.8931e-02, ..., -1.8311e-02,\n", + " 1.0620e-02, 2.3041e-03],\n", + " [ 1.8457e-01, 2.1387e-01, 7.8735e-03, ..., -3.6133e-01,\n", + " -5.0049e-02, 1.7383e-01],\n", + " ...,\n", + " [ 8.0566e-02, -2.9297e-01, 4.5312e-01, ..., -6.8848e-02,\n", + " 1.3379e-01, -1.6895e-01],\n", + " [-1.3672e-01, -5.1562e-01, 4.0527e-02, ..., 1.3516e+00,\n", + " -1.8555e-01, 3.6133e-01],\n", + " [ 4.4141e-01, 4.1406e-01, -1.2451e-01, ..., 5.6641e-01,\n", + " -2.0801e-01, 7.8613e-02]],\n", + "\n", + " [[ 1.4526e-02, 2.5513e-02, -3.7109e-02, ..., -2.0752e-02,\n", + " -4.9316e-02, -1.6968e-02],\n", + " [ 1.3794e-02, 2.6855e-02, -3.3936e-02, ..., -1.9043e-02,\n", + " -4.4434e-02, -1.7456e-02],\n", + " [-2.3560e-02, -1.1279e-01, 2.0703e-01, ..., 4.1504e-03,\n", + " 3.8330e-02, -9.8145e-02],\n", + " ...,\n", + " [ 8.0078e-02, -5.5176e-02, -1.5039e-01, ..., 4.9609e-01,\n", + " 2.0508e-01, -5.1953e-01],\n", + " [ 7.4609e-01, 3.0078e-01, 6.1328e-01, ..., 9.3359e-01,\n", + " 6.3477e-02, -5.3906e-01],\n", + " [-4.0283e-02, -7.8735e-03, 2.0605e-01, ..., 2.3242e-01,\n", + " 1.2793e-01, 3.2227e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 1.5076e-02, 1.1414e-02, -1.9165e-02, ..., -2.2031e+00,\n", + " 2.1777e-01, 5.4688e-01],\n", + " [ 1.2451e-02, 4.2419e-03, -1.7090e-02, ..., -2.2031e+00,\n", + " 2.1875e-01, 5.4297e-01],\n", + " [-4.3750e-01, -6.4453e-01, -2.0117e-01, ..., -1.2109e+00,\n", + " 5.7031e-01, 9.1406e-01],\n", + " ...,\n", + " [-3.4688e+00, -2.0469e+00, 2.8750e+00, ..., 2.8750e+00,\n", + " -4.4189e-02, 1.2656e+00],\n", + " [-2.3906e+00, -2.2031e+00, 3.4375e+00, ..., 3.5312e+00,\n", + " -5.5469e-01, 1.3594e+00],\n", + " [ 4.5410e-02, -1.4609e+00, -7.4219e-02, ..., 2.3750e+00,\n", + " -3.8281e-01, 3.9844e-01]],\n", + "\n", + " [[ 1.1658e-02, -3.0670e-03, -2.2705e-02, ..., -1.2598e-01,\n", + " 3.0625e+00, -1.4297e+00],\n", + " [ 9.7656e-03, 6.3477e-03, -7.3242e-03, ..., -1.2598e-01,\n", + " 3.0625e+00, -1.4297e+00],\n", + " [-4.2188e-01, 4.9609e-01, 4.8047e-01, ..., 1.7334e-02,\n", + " 2.3906e+00, -1.1016e+00],\n", + " ...,\n", + " [-3.0469e+00, -1.2031e+00, -2.6719e+00, ..., 7.1875e-01,\n", + " -1.3750e+00, -5.9570e-02],\n", + " [-8.4375e-01, -3.9062e-01, -2.0625e+00, ..., -4.3750e-01,\n", + " -1.4531e+00, 2.1387e-01],\n", + " [ 2.6953e-01, 1.4062e+00, -1.2812e+00, ..., 5.8984e-01,\n", + " 1.4648e-02, -1.6484e+00]],\n", + "\n", + " [[-7.9956e-03, -1.4709e-02, 1.6928e-05, ..., -7.2266e-01,\n", + " 2.3281e+00, -4.5117e-01],\n", + " [-8.0566e-03, -1.7334e-02, 2.8534e-03, ..., -7.1875e-01,\n", + " 2.3281e+00, -4.4922e-01],\n", + " [-3.9453e-01, 3.7695e-01, -3.3594e-01, ..., -7.9297e-01,\n", + " 2.0469e+00, -3.3789e-01],\n", + " ...,\n", + " [ 1.4375e+00, -3.9648e-01, -1.5078e+00, ..., -8.6719e-01,\n", + " -4.1797e-01, -6.4453e-01],\n", + " [-1.2656e+00, 1.0469e+00, -5.2734e-01, ..., 2.8320e-01,\n", + " 2.9102e-01, -8.8672e-01],\n", + " [-2.5391e-01, 6.9531e-01, -1.2500e-01, ..., -1.2500e-01,\n", + " 6.7578e-01, -2.0215e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-0.0271, -0.0334, -0.0811, ..., 0.0082, 0.0082, 0.0791],\n", + " [-0.0256, -0.0308, -0.0845, ..., 0.0075, 0.0063, 0.0776],\n", + " [ 0.0513, 0.1982, -0.1157, ..., 0.0786, 0.1572, 0.1289],\n", + " ...,\n", + " [-0.0918, 0.6680, -0.2637, ..., -0.1123, 0.0072, -0.0664],\n", + " [ 0.1079, -0.1836, 0.1123, ..., -0.0223, -0.0850, 0.3848],\n", + " [ 0.1963, -0.0708, 0.0493, ..., 0.2969, -0.2119, 0.5000]],\n", + "\n", + " [[ 0.0654, -0.0034, -0.0083, ..., -0.0075, 0.0064, -0.0352],\n", + " [ 0.0635, -0.0043, -0.0057, ..., -0.0056, 0.0052, -0.0349],\n", + " [ 0.1631, -0.2520, -0.1826, ..., -0.0383, -0.0552, -0.0211],\n", + " ...,\n", + " [-0.9375, 0.6992, 0.2930, ..., 1.1953, -0.0481, -0.3262],\n", + " [-0.2910, 0.6992, 0.0796, ..., 0.2715, -0.1826, 0.0510],\n", + " [-0.2031, 0.0576, -0.0120, ..., -0.1504, 0.3828, -0.2490]],\n", + "\n", + " [[-0.0222, 0.0327, -0.0016, ..., -0.0356, 0.0019, -0.0391],\n", + " [-0.0217, 0.0347, -0.0026, ..., -0.0339, 0.0032, -0.0383],\n", + " [ 0.1152, 0.1836, -0.0237, ..., 0.1445, -0.1504, -0.0752],\n", + " ...,\n", + " [ 0.3652, -0.3477, -0.4922, ..., -0.1748, 0.0242, 0.5625],\n", + " [ 0.2402, 0.0518, 0.2061, ..., 0.2051, -0.3359, 0.4512],\n", + " [ 0.1748, -0.3574, -0.2891, ..., 0.0222, -0.6719, 0.2021]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=)), (tensor([[[[-9.7656e-03, 4.9438e-03, -1.8433e-02, ..., -1.3086e-01,\n", + " -1.4766e+00, -6.6016e-01],\n", + " [-2.9297e-03, 1.4526e-02, -1.4587e-02, ..., -1.3281e-01,\n", + " -1.4688e+00, -6.5234e-01],\n", + " [ 2.4219e-01, 3.4766e-01, -4.7119e-02, ..., 1.8848e-01,\n", + " -2.0625e+00, -1.1328e+00],\n", + " ...,\n", + " [-6.7969e-01, -6.9531e-01, 8.1250e-01, ..., 4.5776e-03,\n", + " -1.3438e+00, -1.8047e+00],\n", + " [ 1.3750e+00, 8.8672e-01, 1.1328e-01, ..., 4.5410e-02,\n", + " -6.5234e-01, -1.4453e+00],\n", + " [ 1.8066e-02, -1.2793e-01, -4.5508e-01, ..., -1.2031e+00,\n", + " -1.4941e-01, -1.5547e+00]],\n", + "\n", + " [[-4.0283e-03, 8.9111e-03, 1.2756e-02, ..., -7.0703e-01,\n", + " 2.4688e+00, 2.0156e+00],\n", + " [ 1.3184e-02, 1.3428e-02, 1.7090e-03, ..., -6.9531e-01,\n", + " 2.4531e+00, 2.0156e+00],\n", + " [-3.9062e-02, -1.4355e-01, -1.9531e-03, ..., 3.8086e-01,\n", + " 8.1641e-01, -3.0859e-01],\n", + " ...,\n", + " [ 2.5000e-01, 9.4922e-01, 1.8984e+00, ..., 3.7812e+00,\n", + " -1.7656e+00, -3.1875e+00],\n", + " [ 1.4062e+00, 1.0312e+00, 1.2031e+00, ..., 1.0938e+00,\n", + " -4.7188e+00, -2.1562e+00],\n", + " [-3.8672e-01, -1.7383e-01, -2.8711e-01, ..., 3.0312e+00,\n", + " -4.7812e+00, -3.0625e+00]],\n", + "\n", + " [[-9.2163e-03, -2.7954e-02, 3.9978e-03, ..., -7.7734e-01,\n", + " -2.1875e+00, -1.7188e+00],\n", + " [-7.6294e-04, -2.3071e-02, -5.1117e-04, ..., -7.8516e-01,\n", + " -2.1875e+00, -1.7188e+00],\n", + " [ 6.1768e-02, -2.5000e-01, -7.8613e-02, ..., -2.1562e+00,\n", + " -1.4453e+00, -7.3047e-01],\n", + " ...,\n", + " [-1.4141e+00, -8.6719e-01, 1.8438e+00, ..., -2.9688e+00,\n", + " 1.7969e+00, 1.8433e-02],\n", + " [ 2.7734e-01, -1.1094e+00, 1.4141e+00, ..., -4.1875e+00,\n", + " 1.2266e+00, 3.8672e-01],\n", + " [ 3.1836e-01, 1.5332e-01, 4.2236e-02, ..., -8.9453e-01,\n", + " -1.9824e-01, 1.8984e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[ 3.5889e-02, -5.7861e-02, 5.0049e-02, ..., 2.2168e-01,\n", + " -1.6724e-02, 1.8066e-02],\n", + " [ 3.9551e-02, -5.6396e-02, 4.8340e-02, ..., 2.2070e-01,\n", + " -1.3550e-02, 1.8311e-02],\n", + " [ 2.3340e-01, -1.4258e-01, -7.5684e-02, ..., 3.6523e-01,\n", + " 4.1504e-02, 1.6602e-02],\n", + " ...,\n", + " [-3.8477e-01, -3.7305e-01, -4.6484e-01, ..., 6.1719e-01,\n", + " 9.3359e-01, -4.9609e-01],\n", + " [-6.3672e-01, -7.8125e-01, 4.1992e-01, ..., -4.3945e-01,\n", + " 1.4219e+00, -1.3203e+00],\n", + " [ 7.2266e-01, 2.5195e-01, 7.1875e-01, ..., -4.7852e-01,\n", + " 6.0156e-01, -7.8125e-01]],\n", + "\n", + " [[-2.6245e-03, 2.5269e-02, 9.7656e-03, ..., 1.0437e-02,\n", + " -1.3977e-02, 8.3008e-03],\n", + " [-5.2643e-04, 2.8564e-02, 1.0071e-02, ..., 1.3367e-02,\n", + " -1.4404e-02, 8.8501e-03],\n", + " [-3.4766e-01, 2.3828e-01, 5.7373e-02, ..., -1.3086e-01,\n", + " -1.6992e-01, -1.6602e-01],\n", + " ...,\n", + " [-1.5137e-02, -4.2578e-01, 3.5742e-01, ..., 5.0391e-01,\n", + " -7.5000e-01, -7.9688e-01],\n", + " [ 5.2344e-01, -3.6523e-01, 1.8262e-01, ..., 3.2422e-01,\n", + " 3.7354e-02, -2.4512e-01],\n", + " [-2.6953e-01, 8.8672e-01, -1.5527e-01, ..., 2.4023e-01,\n", + " -3.9258e-01, 4.8096e-02]],\n", + "\n", + " [[ 5.3467e-02, 3.6621e-02, -3.7842e-02, ..., 4.4678e-02,\n", + " -9.2506e-05, -9.7656e-03],\n", + " [ 5.6152e-02, 3.9062e-02, -3.9795e-02, ..., 4.3213e-02,\n", + " -2.7008e-03, -1.0254e-02],\n", + " [-1.2598e-01, 1.3574e-01, -2.1484e-01, ..., -6.8848e-02,\n", + " 1.5234e-01, -5.1514e-02],\n", + " ...,\n", + " [-4.6875e-01, -7.4219e-01, 2.4121e-01, ..., 5.3125e-01,\n", + " -5.4688e-01, -8.4375e-01],\n", + " [-7.5391e-01, -6.4453e-01, 3.1641e-01, ..., 1.4688e+00,\n", + " -2.7148e-01, -6.8750e-01],\n", + " [ 2.8516e-01, 2.1387e-01, -4.4141e-01, ..., 3.5352e-01,\n", + " 8.3984e-01, -9.2578e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-1.9073e-03, -2.0599e-03, 2.0996e-02, ..., -5.5908e-02,\n", + " -5.2812e+00, -1.1406e+00],\n", + " [-1.2329e-02, -2.6093e-03, 1.0620e-02, ..., -6.0547e-02,\n", + " -5.2812e+00, -1.1328e+00],\n", + " [ 4.3164e-01, 1.7480e-01, -3.0859e-01, ..., 3.8757e-03,\n", + " -3.7344e+00, -1.3672e+00],\n", + " ...,\n", + " [ 2.1562e+00, 2.1875e+00, 4.9414e-01, ..., 1.8750e-01,\n", + " 3.1562e+00, -2.1719e+00],\n", + " [ 1.5547e+00, 1.1797e+00, 9.3750e-01, ..., -9.5703e-02,\n", + " 3.3594e+00, -2.6094e+00],\n", + " [-3.4570e-01, -1.2695e-02, 4.4531e-01, ..., 6.8359e-02,\n", + " -4.7656e-01, -1.9766e+00]],\n", + "\n", + " [[-7.0496e-03, 2.5391e-02, 5.1575e-03, ..., -1.3047e+00,\n", + " 3.0781e+00, 1.4922e+00],\n", + " [ 7.5073e-03, 4.8828e-03, 1.2695e-02, ..., -1.3125e+00,\n", + " 3.0938e+00, 1.4922e+00],\n", + " [ 3.6523e-01, 2.6562e-01, 3.2031e-01, ..., -1.9922e+00,\n", + " 2.1406e+00, 7.1484e-01],\n", + " ...,\n", + " [ 1.9453e+00, -2.2344e+00, 6.6016e-01, ..., -2.4375e+00,\n", + " -1.4688e+00, -1.3359e+00],\n", + " [ 9.7266e-01, -6.2109e-01, -4.7852e-01, ..., -1.4609e+00,\n", + " -7.6953e-01, -2.7500e+00],\n", + " [-3.9258e-01, 8.2031e-02, -4.8242e-01, ..., 4.3555e-01,\n", + " 1.3203e+00, 2.2031e+00]],\n", + "\n", + " [[ 1.5747e-02, -1.2329e-02, -6.5613e-03, ..., -1.1250e+00,\n", + " 1.1562e+00, 2.0000e+00],\n", + " [ 2.5024e-02, 3.0518e-03, -7.9346e-03, ..., -1.1328e+00,\n", + " 1.1484e+00, 1.9922e+00],\n", + " [-8.2031e-02, 1.3867e-01, 2.2656e-01, ..., -1.1484e+00,\n", + " -4.5898e-01, 3.6914e-01],\n", + " ...,\n", + " [ 6.0938e-01, 1.0547e+00, -1.3047e+00, ..., -1.8125e+00,\n", + " -4.2188e+00, -5.5938e+00],\n", + " [-3.3203e-01, 9.1406e-01, -1.5469e+00, ..., -1.4844e+00,\n", + " -3.2969e+00, -4.5625e+00],\n", + " [-1.6797e-01, 8.3496e-02, -1.2891e-01, ..., -2.8125e+00,\n", + " -5.5078e-01, -6.7188e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-5.3406e-03, -3.4424e-02, 5.4199e-02, ..., -2.1484e-02,\n", + " 1.5198e-02, -2.1973e-02],\n", + " [-5.6458e-03, -3.9062e-02, 5.5908e-02, ..., -2.1484e-02,\n", + " 1.2695e-02, -1.8433e-02],\n", + " [ 2.2852e-01, 2.4414e-01, -2.5586e-01, ..., 6.9824e-02,\n", + " -1.1377e-01, 2.7148e-01],\n", + " ...,\n", + " [-2.6953e-01, -4.0039e-01, -3.8867e-01, ..., -2.2949e-01,\n", + " -8.1250e-01, -1.8164e-01],\n", + " [-8.2812e-01, -7.6172e-01, 7.8125e-02, ..., -2.5879e-02,\n", + " -6.8848e-02, 8.3618e-03],\n", + " [ 5.9375e-01, 2.2266e-01, 5.4932e-02, ..., 1.8848e-01,\n", + " -9.5215e-02, -3.4180e-01]],\n", + "\n", + " [[-3.6865e-02, 2.5757e-02, -6.7383e-02, ..., 6.2988e-02,\n", + " 7.7637e-02, 1.1108e-02],\n", + " [-3.8574e-02, 2.5635e-02, -6.8359e-02, ..., 6.2500e-02,\n", + " 7.9102e-02, 8.1177e-03],\n", + " [ 2.5558e-04, 4.7119e-02, 3.8330e-02, ..., 9.5703e-02,\n", + " -1.1816e-01, -8.7891e-03],\n", + " ...,\n", + " [-9.6094e-01, -1.1328e+00, 1.1797e+00, ..., -7.1289e-02,\n", + " -2.2559e-01, -2.8516e-01],\n", + " [ 6.4844e-01, -3.2031e-01, 2.0117e-01, ..., 6.0547e-01,\n", + " -1.8750e-01, -4.1406e-01],\n", + " [ 5.2344e-01, 1.0156e+00, 4.0625e-01, ..., -3.1836e-01,\n", + " -3.6719e-01, -9.3262e-02]],\n", + "\n", + " [[ 8.7402e-02, 3.3447e-02, -5.7861e-02, ..., 1.9629e-01,\n", + " 1.9897e-02, -2.2031e+00],\n", + " [ 8.1055e-02, 2.8198e-02, -5.7129e-02, ..., 1.9727e-01,\n", + " 2.2461e-02, -2.2188e+00],\n", + " [ 1.6699e-01, 8.9355e-02, 7.6172e-02, ..., -1.2158e-01,\n", + " 3.7305e-01, -1.7422e+00],\n", + " ...,\n", + " [-6.6797e-01, -4.5117e-01, 3.9844e-01, ..., -1.0938e+00,\n", + " -6.9141e-01, 4.1260e-02],\n", + " [-1.8828e+00, -5.1514e-02, 6.1719e-01, ..., -1.4531e+00,\n", + " -3.5352e-01, -6.6895e-02],\n", + " [-1.4844e+00, 8.2812e-01, -1.4355e-01, ..., -8.4375e-01,\n", + " 4.3945e-02, 1.5391e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-1.2131e-03, -1.1673e-03, -5.3406e-03, ..., -1.7812e+00,\n", + " 8.1641e-01, 9.4141e-01],\n", + " [ 2.9602e-03, 4.4861e-03, -1.3855e-02, ..., -1.7812e+00,\n", + " 8.2031e-01, 9.3750e-01],\n", + " [-7.8516e-01, -5.0781e-01, -3.4766e-01, ..., -1.4609e+00,\n", + " 1.2109e+00, 1.0625e+00],\n", + " ...,\n", + " [ 2.8125e+00, -1.6211e-01, -7.8906e-01, ..., -5.4688e-01,\n", + " 1.6094e+00, 1.1875e+00],\n", + " [-2.4219e+00, -1.4453e+00, 7.3438e-01, ..., -3.7305e-01,\n", + " 9.8047e-01, 2.1250e+00],\n", + " [-5.6250e+00, -2.2656e+00, 1.7891e+00, ..., -2.0625e+00,\n", + " 1.1328e+00, 1.9844e+00]],\n", + "\n", + " [[-1.4526e-02, 7.8125e-03, -1.9043e-02, ..., -4.2500e+00,\n", + " 1.6602e-02, -6.0547e-01],\n", + " [ 5.9509e-03, 9.3384e-03, -1.9897e-02, ..., -4.2500e+00,\n", + " 1.7700e-02, -6.0938e-01],\n", + " [-6.6406e-02, -2.5586e-01, -8.2031e-02, ..., -2.2188e+00,\n", + " -1.7676e-01, -8.1250e-01],\n", + " ...,\n", + " [ 2.4375e+00, 1.3984e+00, 9.0625e-01, ..., 2.4062e+00,\n", + " 9.8828e-01, -6.6016e-01],\n", + " [-7.8906e-01, 1.5234e-01, 1.9531e-02, ..., 2.5781e+00,\n", + " 1.4375e+00, -6.0156e-01],\n", + " [ 3.1006e-02, 2.0020e-02, 1.0391e+00, ..., 1.1875e+00,\n", + " 1.0596e-01, -1.2031e+00]],\n", + "\n", + " [[ 9.9182e-04, 4.0894e-03, 6.0730e-03, ..., 1.3750e+00,\n", + " 1.7656e+00, 3.2812e-01],\n", + " [ 1.1414e-02, -5.8289e-03, 2.5940e-03, ..., 1.3828e+00,\n", + " 1.7656e+00, 3.1836e-01],\n", + " [ 1.6113e-01, 2.5781e-01, 1.7480e-01, ..., 1.4531e+00,\n", + " 1.1953e+00, -8.9844e-02],\n", + " ...,\n", + " [-1.0391e+00, 6.7188e-01, 2.0781e+00, ..., 1.7188e+00,\n", + " -1.3281e+00, 9.8047e-01],\n", + " [ 1.5781e+00, 1.5781e+00, 3.9062e-01, ..., 1.9766e+00,\n", + " -3.9844e-01, 2.3926e-01],\n", + " [ 4.2188e-01, 6.1328e-01, -2.5195e-01, ..., 1.1094e+00,\n", + " 8.6328e-01, -1.1016e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-6.3965e-02, -4.3701e-02, 1.0303e-01, ..., 1.0071e-02,\n", + " 3.1891e-03, -6.8359e-02],\n", + " [-6.2500e-02, -3.9307e-02, 1.0107e-01, ..., 9.8267e-03,\n", + " 1.1978e-03, -6.5918e-02],\n", + " [ 1.2695e-01, 6.0791e-02, 1.7090e-01, ..., -2.2559e-01,\n", + " -2.3438e-02, -5.9326e-02],\n", + " ...,\n", + " [-2.4023e-01, 3.4668e-02, 6.2109e-01, ..., -8.5156e-01,\n", + " 3.0273e-01, -1.3867e-01],\n", + " [-6.7871e-02, -1.5723e-01, 8.7891e-01, ..., -3.6328e-01,\n", + " 1.3379e-01, 9.6798e-05],\n", + " [ 2.8125e-01, 2.5195e-01, 9.8438e-01, ..., -2.6562e-01,\n", + " 2.3926e-01, -1.1523e-01]],\n", + "\n", + " [[ 6.5918e-03, 9.0332e-03, 2.0386e-02, ..., 1.9531e-02,\n", + " 1.6357e-02, 2.6123e-02],\n", + " [ 4.3640e-03, 1.2512e-02, 2.2339e-02, ..., 1.7822e-02,\n", + " 1.7822e-02, 2.6489e-02],\n", + " [-8.4961e-02, 1.6699e-01, -2.3804e-02, ..., -3.0884e-02,\n", + " 1.4062e-01, 4.4922e-02],\n", + " ...,\n", + " [ 3.5742e-01, 1.9238e-01, 8.0859e-01, ..., 1.0791e-01,\n", + " -5.6250e-01, -9.7168e-02],\n", + " [ 5.2979e-02, 7.3047e-01, 1.7266e+00, ..., 4.9805e-01,\n", + " -8.0078e-02, -2.2754e-01],\n", + " [-5.7031e-01, -6.9531e-01, -6.5918e-02, ..., -4.2188e-01,\n", + " 7.8613e-02, 2.1729e-02]],\n", + "\n", + " [[-6.9336e-02, -2.6172e-01, 3.2959e-02, ..., 4.5898e-02,\n", + " 1.1292e-02, 2.6367e-02],\n", + " [-6.9336e-02, -2.5781e-01, 3.1738e-02, ..., 4.5410e-02,\n", + " 9.4604e-03, 2.3560e-02],\n", + " [-1.6797e-01, 1.6113e-01, -5.4199e-02, ..., 9.8145e-02,\n", + " 1.5527e-01, -8.3008e-02],\n", + " ...,\n", + " [ 6.1719e-01, 4.8242e-01, -5.5859e-01, ..., -8.2031e-02,\n", + " -3.5156e-01, -5.2979e-02],\n", + " [ 3.1836e-01, 1.1406e+00, -7.4219e-01, ..., -6.0938e-01,\n", + " -3.0078e-01, 2.9492e-01],\n", + " [ 1.0156e-01, 3.4766e-01, -5.5469e-01, ..., 5.8594e-01,\n", + " -1.4746e-01, 2.4609e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 3.8300e-03, 3.1128e-03, 1.2512e-02, ..., 3.0273e-01,\n", + " -3.8906e+00, 6.0938e-01],\n", + " [ 1.7242e-03, 1.2268e-02, 1.6357e-02, ..., 3.0078e-01,\n", + " -3.8906e+00, 6.1328e-01],\n", + " [-3.1738e-02, 1.8164e-01, -5.6885e-02, ..., 6.5625e-01,\n", + " -1.7969e+00, 1.9922e-01],\n", + " ...,\n", + " [ 4.3359e-01, 7.1484e-01, 5.5859e-01, ..., 4.9023e-01,\n", + " 2.0625e+00, -7.3047e-01],\n", + " [-5.5078e-01, 1.0391e+00, 6.5625e-01, ..., 6.7578e-01,\n", + " 2.2031e+00, 7.1094e-01],\n", + " [ 1.1914e-01, -8.7402e-02, 1.4746e-01, ..., 7.4219e-01,\n", + " 4.4336e-01, 2.6562e+00]],\n", + "\n", + " [[ 1.5015e-02, -1.5442e-02, 2.3804e-02, ..., -5.5078e-01,\n", + " -1.2266e+00, -1.6875e+00],\n", + " [-8.2397e-03, -3.6011e-03, 2.4292e-02, ..., -5.5859e-01,\n", + " -1.2266e+00, -1.6797e+00],\n", + " [-2.5879e-02, -8.1055e-02, 2.5781e-01, ..., -1.1641e+00,\n", + " 1.8652e-01, 8.7500e-01],\n", + " ...,\n", + " [ 9.9219e-01, -1.2188e+00, -5.7812e-01, ..., -2.3281e+00,\n", + " 3.7188e+00, 1.1625e+01],\n", + " [ 6.3672e-01, -1.2031e+00, -1.3047e+00, ..., -8.3203e-01,\n", + " 6.0312e+00, 1.0500e+01],\n", + " [-1.1279e-01, -8.3984e-01, -1.3086e-01, ..., -5.6562e+00,\n", + " 2.7188e+00, 7.4062e+00]],\n", + "\n", + " [[ 1.0910e-03, 9.1553e-03, 6.9275e-03, ..., -5.4199e-02,\n", + " -3.7656e+00, 1.9141e-01],\n", + " [ 3.8147e-03, 8.7280e-03, 9.7656e-03, ..., -6.0547e-02,\n", + " -3.7656e+00, 1.8457e-01],\n", + " [ 2.0605e-01, -2.3340e-01, 5.9375e-01, ..., 2.9785e-02,\n", + " -1.5859e+00, -1.6016e-01],\n", + " ...,\n", + " [ 2.9531e+00, -3.3984e-01, -3.6328e-01, ..., 1.4609e+00,\n", + " 4.0938e+00, 3.8086e-01],\n", + " [ 7.0312e-01, -7.5781e-01, -1.6875e+00, ..., 7.7734e-01,\n", + " 3.7344e+00, 7.5781e-01],\n", + " [-5.8594e-01, -1.3594e+00, -8.5547e-01, ..., -2.2969e+00,\n", + " 1.6250e+00, 7.8906e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-1.5625e-02, -8.0078e-02, 6.2256e-03, ..., 1.8921e-02,\n", + " 1.6309e-01, 8.1543e-02],\n", + " [-1.5198e-02, -8.2031e-02, 4.6997e-03, ..., 1.4465e-02,\n", + " 1.6406e-01, 7.7637e-02],\n", + " [ 1.7090e-01, -2.6562e-01, 1.9141e-01, ..., 8.3496e-02,\n", + " 6.4941e-02, -2.1851e-02],\n", + " ...,\n", + " [ 8.3008e-02, -5.5859e-01, 1.0703e+00, ..., 2.3594e+00,\n", + " 3.0273e-01, 2.0605e-01],\n", + " [-6.2500e-01, -3.7695e-01, 2.6562e-01, ..., 1.6094e+00,\n", + " 1.0547e+00, -2.3730e-01],\n", + " [ 9.6875e-01, 1.2158e-01, -4.2773e-01, ..., 7.7344e-01,\n", + " 1.6895e-01, -2.2656e-01]],\n", + "\n", + " [[-2.6123e-02, -1.7700e-02, 4.2725e-02, ..., -6.7139e-03,\n", + " -2.3682e-02, 5.6152e-02],\n", + " [-2.9297e-02, -1.6479e-02, 3.8574e-02, ..., -2.3804e-03,\n", + " -2.4536e-02, 5.4443e-02],\n", + " [ 3.8477e-01, 6.2012e-02, 1.4258e-01, ..., 9.7656e-02,\n", + " -7.2266e-02, -2.4805e-01],\n", + " ...,\n", + " [ 4.0820e-01, 4.1992e-01, 2.4292e-02, ..., 1.2061e-01,\n", + " 5.3906e-01, -3.6328e-01],\n", + " [-3.3789e-01, 6.5234e-01, 5.3955e-02, ..., 8.8867e-02,\n", + " 5.7422e-01, -3.3008e-01],\n", + " [ 1.0303e-01, 2.9297e-01, 5.7031e-01, ..., -6.9531e-01,\n", + " 1.2109e+00, -6.0938e-01]],\n", + "\n", + " [[ 2.2602e-04, -3.5889e-02, 1.6556e-03, ..., 8.0078e-02,\n", + " 1.4099e-02, -1.2268e-02],\n", + " [ 2.5330e-03, -3.1006e-02, 1.1902e-03, ..., 8.1055e-02,\n", + " 1.5503e-02, -1.1169e-02],\n", + " [-2.5195e-01, 3.1982e-02, 9.9121e-02, ..., 5.3711e-02,\n", + " -1.6235e-02, -2.0996e-02],\n", + " ...,\n", + " [-6.2891e-01, 2.0801e-01, 4.1992e-01, ..., 1.0703e+00,\n", + " -2.7222e-02, 1.6484e+00],\n", + " [-9.0234e-01, 1.0547e+00, 3.0859e-01, ..., -5.4688e-02,\n", + " -4.9219e-01, 9.4922e-01],\n", + " [ 2.2070e-01, 2.4902e-01, 3.3008e-01, ..., -4.3555e-01,\n", + " 1.7944e-02, 5.9766e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-1.2085e-02, -5.2795e-03, 1.9073e-03, ..., -8.6719e-01,\n", + " -2.7344e-01, -7.1484e-01],\n", + " [-1.9897e-02, -5.2490e-03, -6.2256e-03, ..., -8.6328e-01,\n", + " -2.6758e-01, -7.1484e-01],\n", + " [-3.6133e-02, 5.8594e-01, 1.7871e-01, ..., -1.5000e+00,\n", + " 2.8442e-02, -5.1562e-01],\n", + " ...,\n", + " [-1.9531e+00, 4.9414e-01, -5.3125e-01, ..., -2.2812e+00,\n", + " 1.0986e-01, -8.2422e-01],\n", + " [ 1.2500e-01, 1.5234e+00, -5.8594e-01, ..., -1.5625e+00,\n", + " -3.5352e-01, -1.4062e+00],\n", + " [ 3.4180e-02, 1.2891e-01, -3.9062e-02, ..., -1.1875e+00,\n", + " 2.1875e-01, -9.7656e-01]],\n", + "\n", + " [[-1.6357e-02, 8.3618e-03, -6.2866e-03, ..., -4.1016e-02,\n", + " 1.4551e-01, 1.3750e+00],\n", + " [-1.4648e-02, 2.5513e-02, -9.7046e-03, ..., -3.8086e-02,\n", + " 1.4258e-01, 1.3750e+00],\n", + " [-5.5469e-01, -3.6133e-02, -2.1582e-01, ..., 3.3203e-01,\n", + " -6.2500e-02, 1.8594e+00],\n", + " ...,\n", + " [-8.5547e-01, 2.8125e+00, -1.3203e+00, ..., 2.2812e+00,\n", + " -9.4141e-01, 2.5938e+00],\n", + " [-1.8203e+00, 1.5781e+00, -9.8047e-01, ..., 4.9219e-01,\n", + " -9.1406e-01, 1.9688e+00],\n", + " [-7.6172e-02, -5.8594e-01, -5.8594e-02, ..., -8.5938e-01,\n", + " 2.6172e-01, -3.9648e-01]],\n", + "\n", + " [[ 1.7548e-03, 2.1515e-03, 8.1787e-03, ..., -1.6504e-01,\n", + " -1.0703e+00, 1.6016e+00],\n", + " [ 9.2163e-03, -4.1504e-03, 8.1177e-03, ..., -1.7090e-01,\n", + " -1.0547e+00, 1.5859e+00],\n", + " [-4.4189e-02, -6.0547e-02, -4.7363e-02, ..., 7.6660e-02,\n", + " 6.4062e-01, -9.8047e-01],\n", + " ...,\n", + " [-4.0430e-01, 2.1777e-01, 1.9531e-01, ..., -2.4062e+00,\n", + " 1.0438e+01, -2.9531e+00],\n", + " [-2.6562e-01, -1.4648e-01, 4.1211e-01, ..., 4.3438e+00,\n", + " 7.0938e+00, -1.0500e+01],\n", + " [ 1.0205e-01, -5.8594e-02, 1.9922e-01, ..., 1.9141e+00,\n", + " 4.0938e+00, -1.0938e+01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-3.0396e-02, 4.7607e-02, 7.2327e-03, ..., 1.8066e-02,\n", + " -8.8867e-02, -5.0537e-02],\n", + " [-2.8320e-02, 4.1260e-02, 4.0283e-03, ..., 1.6479e-02,\n", + " -8.4473e-02, -5.1270e-02],\n", + " [ 4.2773e-01, -2.3340e-01, -3.2227e-02, ..., -6.2988e-02,\n", + " 5.7031e-01, 1.9531e-01],\n", + " ...,\n", + " [-3.3594e-01, -1.7871e-01, 7.1875e-01, ..., 2.5000e-01,\n", + " -9.2969e-01, -7.8735e-03],\n", + " [-8.2422e-01, -5.5859e-01, 1.0781e+00, ..., 1.8848e-01,\n", + " -1.2500e+00, 5.0000e-01],\n", + " [-8.9062e-01, -1.3984e+00, -2.1484e-02, ..., -1.0059e-01,\n", + " -1.5564e-02, -4.6484e-01]],\n", + "\n", + " [[-3.2471e-02, 7.8735e-03, 1.8677e-02, ..., -3.4668e-02,\n", + " -2.4048e-02, 1.7578e-02],\n", + " [-3.2715e-02, 5.1880e-03, 1.7944e-02, ..., -3.6377e-02,\n", + " -1.7944e-02, 1.6724e-02],\n", + " [-8.6060e-03, 1.3086e-01, -1.7383e-01, ..., -1.1865e-01,\n", + " -1.7285e-01, 6.3705e-04],\n", + " ...,\n", + " [ 8.3984e-02, 3.9258e-01, -2.1191e-01, ..., 4.6875e-01,\n", + " 6.5430e-02, 2.6367e-01],\n", + " [-9.1309e-02, -9.7656e-02, 2.7148e-01, ..., 7.3438e-01,\n", + " 3.5352e-01, 3.9062e-01],\n", + " [-1.8945e-01, 1.5430e-01, 3.7109e-01, ..., 1.6479e-02,\n", + " 5.2344e-01, -3.2422e-01]],\n", + "\n", + " [[ 3.6621e-03, -4.2114e-03, -2.2461e-02, ..., 3.3447e-02,\n", + " -1.1292e-02, -3.5667e-04],\n", + " [ 7.0572e-04, -2.9755e-03, -2.0874e-02, ..., 2.6367e-02,\n", + " -1.4465e-02, 2.8419e-04],\n", + " [-1.7578e-01, 2.3145e-01, 2.5513e-02, ..., 1.4893e-02,\n", + " -1.3477e-01, 6.2256e-02],\n", + " ...,\n", + " [-2.3438e-01, -8.5156e-01, -4.3945e-01, ..., -1.9922e+00,\n", + " 1.5547e+00, -1.7266e+00],\n", + " [-1.1797e+00, -5.0781e-01, 2.0508e-01, ..., -1.3203e+00,\n", + " 1.7031e+00, -5.3516e-01],\n", + " [ 1.4766e+00, 6.9531e-01, 2.0625e+00, ..., -5.8984e-01,\n", + " -1.4844e+00, 1.5391e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 2.6123e-02, 2.4780e-02, 6.0730e-03, ..., 1.0234e+00,\n", + " 3.1562e+00, -3.9648e-01],\n", + " [ 3.7109e-02, 1.2817e-02, 5.4016e-03, ..., 1.0156e+00,\n", + " 3.1562e+00, -3.9648e-01],\n", + " [ 2.9541e-02, 9.7656e-02, 9.8633e-02, ..., 8.7109e-01,\n", + " 2.2656e+00, -4.0820e-01],\n", + " ...,\n", + " [-6.0938e-01, -1.5156e+00, -6.9531e-01, ..., 1.3516e+00,\n", + " -1.0391e+00, -3.7891e-01],\n", + " [ 3.5156e-01, -9.5312e-01, -6.9922e-01, ..., 1.8359e+00,\n", + " -1.1328e+00, 5.1172e-01],\n", + " [-1.7212e-02, 5.3711e-02, -5.7812e-01, ..., 2.1719e+00,\n", + " 9.1016e-01, 1.2812e+00]],\n", + "\n", + " [[-1.8311e-03, 2.8076e-03, -3.7689e-03, ..., 2.1680e-01,\n", + " 6.5625e-01, -2.1875e+00],\n", + " [ 2.2888e-03, -1.2146e-02, -5.7983e-03, ..., 2.2168e-01,\n", + " 6.6016e-01, -2.2031e+00],\n", + " [-3.7109e-02, 1.1328e-01, -6.2500e-01, ..., 2.4023e-01,\n", + " 8.0469e-01, -2.7031e+00],\n", + " ...,\n", + " [ 3.0000e+00, -9.0234e-01, -3.3203e-01, ..., 1.4688e+00,\n", + " -3.4424e-02, -3.1406e+00],\n", + " [ 2.1094e-01, -9.1016e-01, 2.3438e+00, ..., 2.8438e+00,\n", + " 1.5137e-01, -4.1250e+00],\n", + " [-3.6094e+00, 5.3906e-01, 1.6250e+00, ..., 9.8145e-02,\n", + " 1.1562e+00, -4.2188e+00]],\n", + "\n", + " [[ 2.5024e-03, -1.4305e-04, -1.7242e-03, ..., 1.9824e-01,\n", + " -2.1875e-01, -5.1562e-01],\n", + " [ 4.4250e-03, -3.6812e-04, 2.1667e-03, ..., 1.9043e-01,\n", + " -2.1289e-01, -5.1172e-01],\n", + " [-2.0996e-01, 2.4219e-01, -3.6523e-01, ..., 3.2471e-02,\n", + " -2.0410e-01, -7.4219e-01],\n", + " ...,\n", + " [ 2.2969e+00, -3.1094e+00, -2.0469e+00, ..., -4.1406e-01,\n", + " -2.0996e-01, -5.3125e-01],\n", + " [-1.2656e+00, -1.0234e+00, -1.3594e+00, ..., 1.1875e+00,\n", + " -7.7734e-01, -2.5391e-01],\n", + " [-2.8281e+00, 1.1250e+00, -3.9844e-01, ..., -1.2012e-01,\n", + " -8.3594e-01, -1.1562e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[ 0.0289, -0.0146, -0.0486, ..., -0.0474, -0.1416, 0.0503],\n", + " [ 0.0339, -0.0123, -0.0500, ..., -0.0491, -0.1377, 0.0483],\n", + " [ 0.2969, -0.2520, -0.1660, ..., 0.1758, 0.0693, 0.0098],\n", + " ...,\n", + " [-0.1709, -0.3320, -0.0242, ..., 0.5078, -0.2969, -1.0391],\n", + " [-1.2266, -0.1982, -0.0187, ..., -0.2891, 0.2021, -0.5469],\n", + " [-0.8047, -0.2754, -0.8125, ..., -0.0182, 1.4688, 0.7266]],\n", + "\n", + " [[-0.3457, -0.0640, 0.0825, ..., 0.1035, 0.1338, 0.2002],\n", + " [-0.3535, -0.0635, 0.0889, ..., 0.1045, 0.1348, 0.2080],\n", + " [-0.2812, -0.1318, 0.3164, ..., 0.2383, -0.1602, 0.1445],\n", + " ...,\n", + " [ 0.0840, 0.7891, -0.4648, ..., 0.1514, -1.4375, -0.2129],\n", + " [ 0.2559, 0.8594, -0.0713, ..., 0.0471, -0.6133, -0.6289],\n", + " [ 0.0182, 0.6211, -0.1045, ..., 0.5156, -0.1758, -0.7188]],\n", + "\n", + " [[ 0.0583, 0.0101, -0.0330, ..., 0.0182, 0.0669, -0.0282],\n", + " [ 0.0645, 0.0122, -0.0310, ..., 0.0219, 0.0679, -0.0334],\n", + " [ 0.0840, 0.2207, -0.0109, ..., 0.0649, -0.2422, -0.1226],\n", + " ...,\n", + " [ 0.6094, -0.6445, -0.3867, ..., -1.0625, 0.9570, -0.7617],\n", + " [-0.2812, 0.1592, -0.8945, ..., 0.5117, -0.0442, -0.4434],\n", + " [ 0.1875, 0.1035, -0.0371, ..., 0.9297, -0.4199, -0.2402]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=)), (tensor([[[[-1.6846e-02, 4.5776e-03, -1.5991e-02, ..., -1.9629e-01,\n", + " 4.3555e-01, 2.9375e+00],\n", + " [-3.2654e-03, -1.1658e-02, -8.6670e-03, ..., -1.9629e-01,\n", + " 4.3945e-01, 2.9375e+00],\n", + " [ 2.2754e-01, 1.8945e-01, 6.5430e-02, ..., -3.9844e-01,\n", + " 1.7969e-01, 8.1641e-01],\n", + " ...,\n", + " [ 9.6484e-01, -7.8516e-01, -1.6641e+00, ..., -7.4219e-01,\n", + " 9.0625e-01, -3.2031e+00],\n", + " [ 1.3984e+00, -2.3438e-01, -9.3750e-01, ..., -3.3984e-01,\n", + " 6.7578e-01, -3.0312e+00],\n", + " [-2.1777e-01, 1.1328e-01, 2.6562e-01, ..., -4.1797e-01,\n", + " 4.3750e-01, -1.6250e+00]],\n", + "\n", + " [[-3.5553e-03, 1.4404e-02, 2.7954e-02, ..., 1.2031e+00,\n", + " -3.4375e-01, -1.0469e+00],\n", + " [ 2.2095e-02, -2.2583e-03, 2.0752e-03, ..., 1.2109e+00,\n", + " -3.3789e-01, -1.0391e+00],\n", + " [ 9.4727e-02, 8.3984e-02, 5.0659e-03, ..., 2.1250e+00,\n", + " 1.1953e+00, 5.3516e-01],\n", + " ...,\n", + " [-8.3984e-02, -2.4902e-01, 8.3984e-01, ..., 3.4219e+00,\n", + " 1.5312e+00, 3.0312e+00],\n", + " [-1.3672e-01, -4.3359e-01, 7.7344e-01, ..., 2.1719e+00,\n", + " 1.4062e+00, 2.4219e+00],\n", + " [-2.8711e-01, 8.0469e-01, -1.3086e-01, ..., 2.7969e+00,\n", + " 3.3008e-01, 1.9453e+00]],\n", + "\n", + " [[ 2.8687e-02, 1.9455e-03, -5.1880e-03, ..., -9.8438e-01,\n", + " 1.3281e+00, -2.7344e+00],\n", + " [ 2.6611e-02, -2.0790e-04, -1.0437e-02, ..., -9.8438e-01,\n", + " 1.3281e+00, -2.7344e+00],\n", + " [-4.4434e-02, -6.4941e-02, -2.7734e-01, ..., -1.3281e+00,\n", + " -1.0547e-01, -1.2812e+00],\n", + " ...,\n", + " [ 3.0029e-02, -1.1484e+00, -3.5938e-01, ..., -4.0000e+00,\n", + " -4.1250e+00, 7.8906e-01],\n", + " [-2.4219e-01, -5.9766e-01, 6.2500e-02, ..., -3.2812e+00,\n", + " -2.2344e+00, 1.5781e+00],\n", + " [-2.0630e-02, 4.1211e-01, 5.6250e-01, ..., -1.8828e+00,\n", + " 8.2031e-01, 3.6719e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-5.8838e-02, 5.4443e-02, -1.2256e-01, ..., 5.0354e-03,\n", + " -2.1606e-02, -3.7354e-02],\n", + " [-6.0303e-02, 5.1514e-02, -1.1816e-01, ..., 6.0730e-03,\n", + " -2.0874e-02, -3.0884e-02],\n", + " [ 3.4375e-01, 3.6133e-01, -7.0312e-02, ..., -1.2500e-01,\n", + " -2.5977e-01, -1.2598e-01],\n", + " ...,\n", + " [-4.5654e-02, 9.2578e-01, -3.8672e-01, ..., -1.3086e-01,\n", + " -1.6968e-02, -1.2988e-01],\n", + " [-5.2002e-02, 3.5742e-01, -7.5781e-01, ..., 7.3047e-01,\n", + " -1.0703e+00, 2.0020e-02],\n", + " [-1.7344e+00, -5.0391e-01, 1.8359e-01, ..., 1.9775e-02,\n", + " 2.0117e-01, -1.1562e+00]],\n", + "\n", + " [[ 2.3438e-02, -1.8799e-02, -9.3262e-02, ..., 3.1982e-02,\n", + " 5.4443e-02, 7.5073e-03],\n", + " [ 2.0874e-02, -1.2878e-02, -8.8379e-02, ..., 2.6001e-02,\n", + " 5.4199e-02, -2.7466e-04],\n", + " [ 2.6758e-01, -2.7148e-01, -2.7930e-01, ..., 1.7285e-01,\n", + " 1.4453e-01, 3.7598e-02],\n", + " ...,\n", + " [ 1.0156e+00, 3.6523e-01, -1.0078e+00, ..., -3.4570e-01,\n", + " -3.9648e-01, 2.2852e-01],\n", + " [ 1.0547e+00, 9.0234e-01, 1.1016e+00, ..., 4.8242e-01,\n", + " -1.2422e+00, 1.7480e-01],\n", + " [ 1.4688e+00, 1.7383e-01, -2.9688e-01, ..., 8.6426e-02,\n", + " -3.5742e-01, 1.6719e+00]],\n", + "\n", + " [[-1.4771e-02, 3.2715e-02, -1.5320e-02, ..., 1.0559e-02,\n", + " -1.1719e-02, -2.1973e-03],\n", + " [-2.4536e-02, 2.9541e-02, -1.4709e-02, ..., 3.2501e-03,\n", + " -3.4637e-03, 4.4861e-03],\n", + " [ 2.4219e-01, 1.7871e-01, 1.1328e-01, ..., -2.5391e-01,\n", + " 1.3379e-01, 1.7383e-01],\n", + " ...,\n", + " [-8.7891e-03, 3.4375e-01, -1.3516e+00, ..., 1.8984e+00,\n", + " -3.8086e-01, 2.6172e-01],\n", + " [ 7.6562e-01, 1.0703e+00, 1.2188e+00, ..., -2.9688e-01,\n", + " 1.2695e-01, 8.5156e-01],\n", + " [-7.6953e-01, 2.0469e+00, 1.0781e+00, ..., -2.1250e+00,\n", + " 3.4766e-01, 2.0508e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-2.0027e-04, 3.2196e-03, 2.0386e-02, ..., -2.6758e-01,\n", + " -2.9883e-01, 1.4375e+00],\n", + " [ 5.1880e-03, 3.1891e-03, 1.2146e-02, ..., -2.6758e-01,\n", + " -3.1055e-01, 1.4375e+00],\n", + " [-4.4727e-01, -2.5586e-01, 1.8945e-01, ..., -2.8320e-01,\n", + " -4.5703e-01, 1.9141e+00],\n", + " ...,\n", + " [-7.8125e-01, 1.1172e+00, -5.1562e-01, ..., 1.6562e+00,\n", + " 1.1875e+00, 1.5938e+00],\n", + " [-1.5156e+00, -7.5781e-01, -4.2969e-02, ..., 8.6328e-01,\n", + " 1.0547e+00, 2.2812e+00],\n", + " [-1.6797e-01, -9.4531e-01, -1.1406e+00, ..., -7.2656e-01,\n", + " -1.9922e+00, 1.2266e+00]],\n", + "\n", + " [[-7.6904e-03, 1.8768e-03, -1.6602e-02, ..., 3.7109e-01,\n", + " -3.8757e-03, 1.3828e+00],\n", + " [ 6.6528e-03, 8.9111e-03, -2.1973e-03, ..., 3.7305e-01,\n", + " -6.1035e-03, 1.3672e+00],\n", + " [-6.0303e-02, -1.3867e-01, 6.5918e-02, ..., 7.6562e-01,\n", + " -1.9043e-01, 2.9688e-01],\n", + " ...,\n", + " [-3.1055e-01, 1.3828e+00, 8.3984e-02, ..., 1.3359e+00,\n", + " 7.6172e-01, -2.4805e-01],\n", + " [-1.1406e+00, 9.1406e-01, 4.5703e-01, ..., 1.3281e+00,\n", + " 3.4961e-01, 6.9336e-02],\n", + " [ 2.6367e-01, -3.7891e-01, -3.4766e-01, ..., -1.2891e+00,\n", + " -1.4219e+00, -2.1094e+00]],\n", + "\n", + " [[ 1.3428e-02, 1.4648e-02, 2.4780e-02, ..., 3.9844e-01,\n", + " 2.1362e-02, -1.1797e+00],\n", + " [ 4.5166e-02, 6.1646e-03, 2.4658e-02, ..., 4.0039e-01,\n", + " 2.5635e-02, -1.1953e+00],\n", + " [ 1.0498e-01, -5.3711e-03, 1.3379e-01, ..., 5.2734e-01,\n", + " -1.2988e-01, -3.0938e+00],\n", + " ...,\n", + " [ 1.9062e+00, -1.7285e-01, 6.9141e-01, ..., 2.2031e+00,\n", + " -2.3750e+00, -3.0938e+00],\n", + " [ 4.6875e-02, 6.1523e-02, -1.9141e-01, ..., 9.5312e-01,\n", + " -2.4844e+00, -3.4688e+00],\n", + " [ 1.6797e-01, -2.1094e-01, 3.7109e-01, ..., 2.0000e+00,\n", + " -7.3047e-01, -7.0312e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-5.5847e-03, 1.2817e-02, 1.3855e-02, ..., 2.4048e-02,\n", + " 1.6113e-02, 3.5095e-03],\n", + " [-1.2085e-02, 1.5076e-02, 1.6602e-02, ..., 2.3560e-02,\n", + " 1.5747e-02, -3.5286e-04],\n", + " [-1.6406e-01, -1.5332e-01, -1.8750e-01, ..., 8.7891e-02,\n", + " -1.2012e-01, -6.3477e-02],\n", + " ...,\n", + " [-2.5625e+00, -1.5469e+00, -6.1719e-01, ..., 1.6016e+00,\n", + " -1.1719e+00, -8.9453e-01],\n", + " [-8.7109e-01, -1.8652e-01, -7.5781e-01, ..., -4.9219e-01,\n", + " 1.1963e-01, -1.2695e-01],\n", + " [ 1.0156e-01, -4.0430e-01, -1.0156e-01, ..., -1.1094e+00,\n", + " -5.3516e-01, 1.1562e+00]],\n", + "\n", + " [[ 7.7209e-03, 8.4839e-03, -1.4587e-02, ..., 4.4922e-02,\n", + " 1.0596e-01, -6.3281e-01],\n", + " [ 6.1646e-03, 4.9744e-03, -1.0620e-02, ..., 4.1504e-02,\n", + " 1.0840e-01, -6.4062e-01],\n", + " [ 2.5391e-02, 3.6914e-01, 3.6328e-01, ..., -1.2598e-01,\n", + " -1.5332e-01, -1.2891e+00],\n", + " ...,\n", + " [ 9.0234e-01, 3.1836e-01, 7.5195e-02, ..., -5.9375e-01,\n", + " -1.5547e+00, -3.4062e+00],\n", + " [ 1.2578e+00, -8.2422e-01, -6.9141e-01, ..., -1.1328e+00,\n", + " -1.6484e+00, -3.3281e+00],\n", + " [ 1.6719e+00, -1.4609e+00, -9.6484e-01, ..., 6.9531e-01,\n", + " -1.2344e+00, -2.7656e+00]],\n", + "\n", + " [[ 8.0859e-01, 5.2734e-02, -3.1494e-02, ..., -1.7212e-02,\n", + " 2.4707e-01, 4.5166e-02],\n", + " [ 8.0859e-01, 5.2979e-02, -3.3936e-02, ..., -1.2573e-02,\n", + " 2.4512e-01, 4.4922e-02],\n", + " [ 4.4727e-01, 3.4912e-02, -1.4258e-01, ..., 1.1572e-01,\n", + " 1.6699e-01, 2.1973e-01],\n", + " ...,\n", + " [-6.8359e-01, 3.1641e-01, -4.3750e-01, ..., -2.5000e-01,\n", + " 6.7969e-01, 9.9609e-01],\n", + " [ 1.2695e-01, -6.4062e-01, -4.3945e-01, ..., -1.4062e+00,\n", + " 1.3379e-01, -7.3438e-01],\n", + " [-3.0273e-01, 3.4766e-01, -2.6703e-04, ..., 9.7168e-02,\n", + " 1.8203e+00, -4.3164e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 2.8076e-03, 4.6387e-03, 5.4932e-03, ..., 4.5703e-01,\n", + " -1.5547e+00, -2.8438e+00],\n", + " [ 5.9814e-03, 1.1292e-03, -1.3123e-03, ..., 4.5312e-01,\n", + " -1.5547e+00, -2.8594e+00],\n", + " [ 8.7109e-01, 5.5859e-01, 1.4062e-01, ..., 5.1562e-01,\n", + " -1.8125e+00, -2.5312e+00],\n", + " ...,\n", + " [ 3.6875e+00, 1.6172e+00, -1.2656e+00, ..., 6.7969e-01,\n", + " -1.4219e+00, 5.0000e-01],\n", + " [ 3.7812e+00, 3.1250e+00, -1.6406e+00, ..., 3.4180e-02,\n", + " -1.4844e+00, -1.4688e+00],\n", + " [ 1.1250e+00, 7.1484e-01, -1.5625e-01, ..., 7.7734e-01,\n", + " -3.0938e+00, -2.3125e+00]],\n", + "\n", + " [[ 7.7209e-03, 1.0559e-02, -4.2419e-03, ..., 2.6855e-02,\n", + " -1.3281e+00, -3.3594e-01],\n", + " [ 4.2725e-03, 1.0376e-02, -1.1292e-02, ..., 3.4424e-02,\n", + " -1.3359e+00, -3.3398e-01],\n", + " [-1.0352e-01, -1.1426e-01, -9.5703e-02, ..., 3.4766e-01,\n", + " -2.0625e+00, -3.0273e-02],\n", + " ...,\n", + " [-1.0703e+00, -1.6562e+00, -1.3086e-01, ..., -8.5938e-01,\n", + " -2.4531e+00, 1.7422e+00],\n", + " [-1.0859e+00, -7.5000e-01, -7.1094e-01, ..., -2.3125e+00,\n", + " -3.7344e+00, 8.0078e-01],\n", + " [ 5.1758e-02, -5.7031e-01, -2.0410e-01, ..., -1.0889e-01,\n", + " -3.3281e+00, -6.2109e-01]],\n", + "\n", + " [[-1.2939e-02, 2.6855e-02, 2.9053e-02, ..., -5.1250e+00,\n", + " -1.1016e+00, -8.7109e-01],\n", + " [ 5.1270e-03, 4.1016e-02, 2.2705e-02, ..., -5.1250e+00,\n", + " -1.1172e+00, -8.7500e-01],\n", + " [ 2.1387e-01, 1.4941e-01, 1.2891e-01, ..., -2.8125e+00,\n", + " -2.2188e+00, -7.8906e-01],\n", + " ...,\n", + " [-9.5312e-01, -1.3906e+00, -1.0234e+00, ..., 2.7656e+00,\n", + " -2.8281e+00, -2.0000e+00],\n", + " [-8.2812e-01, 4.3750e-01, -3.5547e-01, ..., 2.6250e+00,\n", + " -3.4688e+00, -3.1094e+00],\n", + " [ 4.7461e-01, 2.4414e-01, 4.6875e-02, ..., 2.4219e-01,\n", + " -2.7031e+00, -1.1641e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[ 0.0148, -0.0513, 0.0315, ..., -0.1445, -0.2539, 0.0518],\n", + " [ 0.0181, -0.0601, 0.0284, ..., -0.1436, -0.2520, 0.0466],\n", + " [ 0.0820, -0.2949, -0.3652, ..., 0.1040, -0.2480, 0.0806],\n", + " ...,\n", + " [-0.1162, 1.5938, 0.3008, ..., 0.1807, -1.0391, 0.1396],\n", + " [ 0.3477, 1.3281, 0.2451, ..., 0.0177, -0.6680, -0.5820],\n", + " [ 0.7266, 0.2129, -0.4492, ..., 1.8828, 0.1709, -0.3223]],\n", + "\n", + " [[ 0.0186, -0.0208, -0.0229, ..., -0.0427, -0.1416, -0.1172],\n", + " [ 0.0098, -0.0092, -0.0211, ..., -0.0449, -0.1465, -0.1250],\n", + " [-0.0306, 0.3340, 0.2100, ..., 0.2715, -0.1289, -0.3848],\n", + " ...,\n", + " [-1.1719, -0.4121, 0.7578, ..., -2.5938, 1.8203, 0.0469],\n", + " [-1.1172, 0.7344, 1.1641, ..., -1.6406, 0.8750, 1.3359],\n", + " [ 1.1484, 0.4395, 0.1836, ..., -0.5977, 0.2129, 0.1523]],\n", + "\n", + " [[-0.0090, -0.0276, -0.1641, ..., 0.0238, 0.0488, 0.0111],\n", + " [-0.0087, -0.0164, -0.1660, ..., 0.0208, 0.0547, 0.0144],\n", + " [-0.2539, 0.2051, 0.0383, ..., 0.1260, 0.2051, -0.1196],\n", + " ...,\n", + " [-0.8555, 0.8203, 1.5781, ..., 0.4707, 0.2656, -0.9453],\n", + " [-0.8359, 0.9805, 1.2188, ..., -0.4180, 0.3027, 0.3516],\n", + " [ 0.6758, -0.0344, -0.1523, ..., 0.5859, 0.2695, 0.4238]]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=)), (tensor([[[[-2.3804e-02, -3.4485e-03, 2.4780e-02, ..., -4.3164e-01,\n", + " 2.2344e+00, -2.2656e+00],\n", + " [-4.4922e-02, -7.3242e-04, 7.6294e-03, ..., -4.2578e-01,\n", + " 2.2344e+00, -2.2500e+00],\n", + " [ 6.1035e-05, 5.6641e-02, -7.4219e-02, ..., -2.8320e-01,\n", + " 4.1562e+00, -6.9531e-01],\n", + " ...,\n", + " [-1.6797e-01, 9.6094e-01, -9.0625e-01, ..., -4.9219e-01,\n", + " 5.7812e+00, 4.5117e-01],\n", + " [ 8.6426e-02, 8.3594e-01, -4.6484e-01, ..., -1.8594e+00,\n", + " 9.0000e+00, 8.0566e-02],\n", + " [-1.0986e-01, -3.2715e-02, -1.3867e-01, ..., 1.3281e+00,\n", + " 5.3750e+00, 2.0625e+00]],\n", + "\n", + " [[ 2.0752e-02, -1.1169e-02, -7.8125e-03, ..., -3.2031e+00,\n", + " 5.7812e-01, -1.2812e+00],\n", + " [ 3.1738e-02, -1.4771e-02, -3.2959e-03, ..., -3.1875e+00,\n", + " 5.7812e-01, -1.2891e+00],\n", + " [-1.9653e-02, -1.4941e-01, 1.3574e-01, ..., -1.3984e+00,\n", + " 7.8125e-01, -1.0703e+00],\n", + " ...,\n", + " [ 1.1328e-01, -8.4375e-01, -1.5137e-01, ..., 3.5938e+00,\n", + " 1.8125e+00, -1.6016e-01],\n", + " [ 1.0078e+00, -6.8750e-01, 4.9805e-02, ..., 4.8750e+00,\n", + " 2.0938e+00, -3.3398e-01],\n", + " [ 7.1875e-01, -8.1055e-02, 6.3281e-01, ..., 2.8906e+00,\n", + " 1.3281e+00, -1.5938e+00]],\n", + "\n", + " [[ 1.1963e-02, 1.3367e-02, 7.0801e-03, ..., 7.0703e-01,\n", + " 2.6758e-01, -2.7812e+00],\n", + " [ 8.2397e-04, -1.8311e-03, 7.7515e-03, ..., 7.0312e-01,\n", + " 2.7148e-01, -2.7812e+00],\n", + " [-1.3867e-01, -1.2793e-01, 3.6621e-02, ..., -1.7700e-02,\n", + " 3.7891e-01, -3.4219e+00],\n", + " ...,\n", + " [-1.5527e-01, 1.8359e-01, -2.4609e-01, ..., 2.2500e+00,\n", + " -2.1719e+00, -1.5469e+00],\n", + " [ 3.9453e-01, -8.0469e-01, 1.0234e+00, ..., 1.8047e+00,\n", + " -6.2891e-01, -3.8125e+00],\n", + " [-2.6953e-01, -7.8613e-02, 7.0312e-02, ..., 9.4531e-01,\n", + " 1.5703e+00, -5.0000e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-8.4686e-04, -6.3965e-02, 2.0264e-02, ..., 3.1738e-02,\n", + " -6.2561e-03, -2.5269e-02],\n", + " [ 1.2146e-02, -7.2266e-02, 1.5747e-02, ..., 2.7100e-02,\n", + " -8.8501e-03, -1.0376e-02],\n", + " [ 1.6309e-01, -4.6289e-01, -8.4229e-03, ..., 3.5156e-02,\n", + " -2.9102e-01, 1.0303e-01],\n", + " ...,\n", + " [-2.8594e+00, -6.2500e-01, 1.4062e+00, ..., -5.7031e-01,\n", + " -1.6016e+00, -4.5000e+00],\n", + " [-2.3594e+00, 4.1875e+00, 2.1094e+00, ..., -9.0625e-01,\n", + " 1.7188e+00, -1.2188e+00],\n", + " [-5.9375e-01, 8.9062e-01, -4.2773e-01, ..., 2.2031e+00,\n", + " 8.8672e-01, 6.5625e-01]],\n", + "\n", + " [[ 1.8677e-02, 2.8809e-02, 1.1475e-02, ..., 3.3691e-02,\n", + " 5.2002e-02, -2.5879e-02],\n", + " [ 6.5613e-03, 2.4292e-02, -6.6528e-03, ..., 2.5391e-02,\n", + " 6.0303e-02, -3.4668e-02],\n", + " [ 6.7969e-01, 4.2969e-01, -2.2070e-01, ..., -1.6602e-02,\n", + " 3.1836e-01, -2.1680e-01],\n", + " ...,\n", + " [-3.6133e-01, -2.5625e+00, -2.2031e+00, ..., 2.4062e+00,\n", + " -2.9219e+00, 1.4219e+00],\n", + " [-1.2969e+00, -1.8652e-01, -3.0938e+00, ..., 2.4219e+00,\n", + " 1.0234e+00, -5.4297e-01],\n", + " [ 7.1875e-01, -2.0469e+00, -1.1865e-01, ..., -1.7266e+00,\n", + " 1.3906e+00, -1.0938e+00]],\n", + "\n", + " [[ 4.9072e-02, -3.5400e-02, -8.3984e-02, ..., -6.4941e-02,\n", + " -1.3916e-02, 8.8379e-02],\n", + " [ 5.6641e-02, -2.1851e-02, -8.2031e-02, ..., -7.2266e-02,\n", + " -1.6479e-02, 7.1777e-02],\n", + " [-9.5703e-02, -1.5234e-01, -1.4160e-01, ..., -5.6641e-01,\n", + " 5.1514e-02, 5.4297e-01],\n", + " ...,\n", + " [ 4.5625e+00, -2.7500e+00, 1.3965e-01, ..., -3.9307e-02,\n", + " 2.9062e+00, 2.9219e+00],\n", + " [ 3.8281e-01, -2.4805e-01, 1.3594e+00, ..., -4.5898e-01,\n", + " -7.4609e-01, 4.9805e-01],\n", + " [ 1.9375e+00, 4.5654e-02, 1.3125e+00, ..., 1.3828e+00,\n", + " -2.5781e+00, -2.1719e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-1.4587e-02, -9.9182e-04, 1.7090e-03, ..., -8.0078e-01,\n", + " -2.3750e+00, -2.0156e+00],\n", + " [ 1.2817e-02, -2.3438e-02, 2.1240e-02, ..., -7.9297e-01,\n", + " -2.3750e+00, -2.0000e+00],\n", + " [ 9.3750e-02, 6.6895e-02, 5.0293e-02, ..., -6.1719e-01,\n", + " -4.6094e-01, -2.5312e+00],\n", + " ...,\n", + " [ 3.6914e-01, 2.9102e-01, -2.9102e-01, ..., -2.6953e-01,\n", + " 1.2656e+00, -3.7656e+00],\n", + " [-3.7109e-02, 5.1562e-01, -1.4844e-01, ..., 2.7344e+00,\n", + " 4.1562e+00, -5.3750e+00],\n", + " [ 2.9492e-01, -9.2773e-03, -1.3379e-01, ..., 2.9844e+00,\n", + " 2.5312e+00, -6.9922e-01]],\n", + "\n", + " [[-1.4648e-02, 5.0354e-03, 4.2725e-03, ..., -3.5742e-01,\n", + " -3.5889e-02, -2.7031e+00],\n", + " [-2.0142e-03, 2.8381e-03, 6.3782e-03, ..., -3.6523e-01,\n", + " -4.0527e-02, -2.6875e+00],\n", + " [-3.5156e-01, -2.6367e-01, 1.9922e-01, ..., -5.1953e-01,\n", + " -1.9043e-01, -6.0547e-01],\n", + " ...,\n", + " [-8.5156e-01, -6.3672e-01, 1.4531e+00, ..., 7.0312e-01,\n", + " 1.1719e+00, 2.5469e+00],\n", + " [-1.4844e+00, -1.2969e+00, 4.1406e-01, ..., -3.9453e-01,\n", + " 4.3701e-02, 3.3438e+00],\n", + " [-7.5781e-01, -1.1016e+00, -6.2891e-01, ..., -2.1406e+00,\n", + " -2.1250e+00, 3.7188e+00]],\n", + "\n", + " [[ 1.3916e-02, 2.9907e-02, -4.2534e-04, ..., -9.4531e-01,\n", + " 2.1875e+00, -3.5742e-01],\n", + " [ 6.0059e-02, 3.0518e-02, -7.6904e-03, ..., -9.4922e-01,\n", + " 2.1875e+00, -3.5938e-01],\n", + " [-6.4453e-02, 1.6113e-01, 1.2891e-01, ..., -1.4219e+00,\n", + " 2.7031e+00, -2.3242e-01],\n", + " ...,\n", + " [ 1.0742e-01, 6.0938e-01, -4.1992e-01, ..., -3.9219e+00,\n", + " 4.5312e+00, 1.8945e-01],\n", + " [-3.2031e-01, 1.0938e+00, -9.2969e-01, ..., -4.0312e+00,\n", + " 4.2500e+00, 3.2344e+00],\n", + " [ 2.4219e-01, -6.4453e-02, -2.7344e-01, ..., -3.1094e+00,\n", + " 1.3516e+00, 4.8828e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-3.1494e-02, -1.3550e-02, 5.9814e-02, ..., 3.1494e-02,\n", + " 1.4465e-02, 5.5420e-02],\n", + " [-4.2725e-02, -2.6367e-02, 6.2500e-02, ..., 1.4709e-02,\n", + " -8.3542e-04, 5.1514e-02],\n", + " [ 4.5117e-01, 1.6309e-01, -1.6406e-01, ..., -5.2246e-02,\n", + " 4.7461e-01, -3.1641e-01],\n", + " ...,\n", + " [ 1.1250e+00, 1.8594e+00, 3.5938e+00, ..., 5.0625e+00,\n", + " 3.2656e+00, -2.6250e+00],\n", + " [-1.4922e+00, 2.0312e+00, -2.4062e+00, ..., -1.1094e+00,\n", + " 3.3438e+00, -2.3281e+00],\n", + " [ 2.6172e-01, 1.9609e+00, 3.5625e+00, ..., -7.9297e-01,\n", + " 1.1250e+00, -3.0469e-01]],\n", + "\n", + " [[ 7.8613e-02, 7.0312e-02, 1.3000e-02, ..., -3.7109e-02,\n", + " -9.0408e-04, 2.6855e-02],\n", + " [ 4.8828e-02, 7.0312e-02, 3.3203e-02, ..., -4.1260e-02,\n", + " 1.7700e-02, 2.6489e-02],\n", + " [ 2.2168e-01, 1.5527e-01, -9.9121e-02, ..., 2.1289e-01,\n", + " 1.3379e-01, 1.2891e-01],\n", + " ...,\n", + " [ 8.9844e-01, 1.6895e-01, -3.3750e+00, ..., 9.3125e+00,\n", + " -9.0234e-01, 2.7344e+00],\n", + " [-1.4297e+00, -2.7500e+00, 8.2422e-01, ..., 1.9688e+00,\n", + " 5.3906e-01, 6.0547e-01],\n", + " [-2.4531e+00, 1.9922e+00, 1.7266e+00, ..., 5.7812e-01,\n", + " -1.3281e+00, -7.6172e-02]],\n", + "\n", + " [[-2.7734e-01, -9.0332e-02, -1.0703e+00, ..., 2.9102e-01,\n", + " 1.8433e-02, 1.4453e-01],\n", + " [-2.9297e-01, -7.7148e-02, -1.0312e+00, ..., 2.9688e-01,\n", + " 4.1016e-02, 1.1914e-01],\n", + " [-2.7539e-01, -6.2891e-01, -1.2188e+00, ..., -1.7578e-01,\n", + " 2.3730e-01, -2.0898e-01],\n", + " ...,\n", + " [ 4.3750e+00, 1.9062e+00, -3.3594e+00, ..., 2.2656e+00,\n", + " 4.3750e-01, -2.5938e+00],\n", + " [-7.2266e-01, 1.8438e+00, -4.2188e+00, ..., -1.3203e+00,\n", + " 8.9844e-01, 2.2812e+00],\n", + " [-3.3906e+00, -9.8438e-01, -3.3438e+00, ..., -2.7734e-01,\n", + " -2.5195e-01, 1.8359e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-8.9111e-03, -2.0508e-02, 1.1780e-02, ..., -8.5938e-01,\n", + " -4.8828e-02, -1.0010e-01],\n", + " [ 2.0874e-02, 5.4321e-03, 1.0986e-02, ..., -8.5156e-01,\n", + " -6.0303e-02, -9.6191e-02],\n", + " [-2.4414e-01, -5.0293e-02, -3.0664e-01, ..., -2.5625e+00,\n", + " -5.4688e-01, 1.2695e-01],\n", + " ...,\n", + " [ 1.5938e+00, -2.4375e+00, -1.3047e+00, ..., -5.5000e+00,\n", + " -7.5000e-01, 2.4805e-01],\n", + " [-8.9844e-02, -1.6719e+00, 9.3750e-01, ..., -6.7500e+00,\n", + " 5.0000e-01, -5.5469e-01],\n", + " [-9.1406e-01, -3.0078e-01, 1.0391e+00, ..., -2.9531e+00,\n", + " -1.7656e+00, 8.4375e-01]],\n", + "\n", + " [[ 6.5918e-03, -1.4771e-02, 1.1292e-02, ..., 2.3340e-01,\n", + " -3.8438e+00, -9.6094e-01],\n", + " [ 2.8687e-03, -2.3926e-02, 7.9346e-03, ..., 2.3730e-01,\n", + " -3.8594e+00, -9.5703e-01],\n", + " [ 6.3672e-01, -8.8867e-02, 1.7969e-01, ..., 3.0664e-01,\n", + " -2.5000e+00, -1.1953e+00],\n", + " ...,\n", + " [ 2.4219e+00, 2.1406e+00, -6.8750e-01, ..., -1.1250e+00,\n", + " 4.2812e+00, -1.0547e+00],\n", + " [ 2.6562e+00, 2.1562e+00, -2.9688e-01, ..., 7.4219e-02,\n", + " 3.2188e+00, 1.0469e+00],\n", + " [ 2.2852e-01, 1.0234e+00, -3.7891e-01, ..., 1.5547e+00,\n", + " -1.4375e+00, -6.4062e-01]],\n", + "\n", + " [[ 1.2360e-03, 3.3264e-03, -2.8839e-03, ..., -3.7031e+00,\n", + " 4.8242e-01, -4.9023e-01],\n", + " [-3.6316e-03, 7.8125e-03, -2.4414e-03, ..., -3.6875e+00,\n", + " 4.8828e-01, -4.9609e-01],\n", + " [-7.0312e-01, -1.9043e-01, -1.9531e-01, ..., -7.4219e-01,\n", + " 5.2734e-01, -4.3945e-01],\n", + " ...,\n", + " [-2.1250e+00, -3.3203e-01, 1.6641e+00, ..., 6.0312e+00,\n", + " 8.3594e-01, -1.0000e+00],\n", + " [-2.2969e+00, -1.6172e+00, 1.5234e+00, ..., 4.8125e+00,\n", + " 7.3047e-01, -5.0391e-01],\n", + " [ 4.9805e-01, -1.0781e+00, -3.6328e-01, ..., 3.3281e+00,\n", + " 1.9727e-01, -7.1484e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[ 1.8921e-02, 4.6387e-02, 2.8320e-02, ..., -4.3457e-02,\n", + " 2.2949e-02, 1.2207e-02],\n", + " [ 1.6235e-02, 4.1992e-02, 2.5024e-02, ..., -4.4678e-02,\n", + " 2.2583e-02, 1.3184e-02],\n", + " [ 8.0078e-02, -2.8125e-01, 2.2559e-01, ..., -3.4961e-01,\n", + " 1.7480e-01, -1.4844e-01],\n", + " ...,\n", + " [-2.7031e+00, 1.6797e+00, 1.4141e+00, ..., -1.1719e+00,\n", + " 6.2891e-01, 1.0078e+00],\n", + " [-7.9688e-01, -2.2070e-01, 2.0469e+00, ..., -1.0234e+00,\n", + " -4.4434e-02, -1.0078e+00],\n", + " [-7.7344e-01, -1.5625e-01, 1.0938e+00, ..., 6.0547e-01,\n", + " 3.4332e-03, -7.2656e-01]],\n", + "\n", + " [[-3.7354e-02, 1.1658e-02, 6.8970e-03, ..., -8.6328e-01,\n", + " -1.9409e-02, 6.8970e-03],\n", + " [-4.9072e-02, 1.0010e-02, 2.2461e-02, ..., -8.5938e-01,\n", + " -1.0925e-02, 2.2339e-02],\n", + " [-2.3535e-01, 1.6113e-01, -2.4805e-01, ..., -1.7383e-01,\n", + " 1.2451e-01, -7.7148e-02],\n", + " ...,\n", + " [-4.5312e+00, 1.1172e+00, -1.3906e+00, ..., 1.4766e+00,\n", + " -2.4062e+00, 1.8984e+00],\n", + " [-3.0312e+00, 7.1484e-01, 2.2656e+00, ..., 2.9844e+00,\n", + " -8.5156e-01, 9.1797e-01],\n", + " [-7.8125e-01, 4.4141e-01, -2.9883e-01, ..., -1.8555e-01,\n", + " -7.2266e-02, -1.3828e+00]],\n", + "\n", + " [[-3.4912e-02, -3.1738e-02, 7.0801e-02, ..., 5.5664e-02,\n", + " 3.3936e-02, 1.7212e-02],\n", + " [-3.6865e-02, -3.4424e-02, 6.8848e-02, ..., 5.7861e-02,\n", + " 2.8931e-02, 1.2634e-02],\n", + " [ 3.4961e-01, -1.9629e-01, 2.0508e-02, ..., -7.0312e-01,\n", + " 1.6406e-01, -3.0518e-02],\n", + " ...,\n", + " [ 1.4062e+00, 2.0625e+00, -4.3438e+00, ..., 1.1250e+00,\n", + " -5.3516e-01, -2.6562e-01],\n", + " [-1.6016e+00, 2.5312e+00, -1.1484e+00, ..., -3.7695e-01,\n", + " -2.1562e+00, 1.8984e+00],\n", + " [ 5.5469e-01, 1.8516e+00, -1.2188e+00, ..., -8.3984e-01,\n", + " -1.4531e+00, -5.7812e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-5.0964e-03, 1.5625e-02, -1.4404e-02, ..., -5.4688e-02,\n", + " -1.9375e+00, 4.6875e-01],\n", + " [ 8.3618e-03, 4.1809e-03, -1.4709e-02, ..., -4.5166e-02,\n", + " -1.9375e+00, 4.5703e-01],\n", + " [-3.4180e-02, 3.5156e-01, 1.1279e-01, ..., 2.7100e-02,\n", + " -2.9219e+00, 1.4746e-01],\n", + " ...,\n", + " [ 2.0000e+00, -2.3242e-01, -8.9844e-02, ..., -2.1094e+00,\n", + " -5.0000e+00, 1.3184e-01],\n", + " [-6.4453e-01, 1.3594e+00, 3.8672e-01, ..., -8.9453e-01,\n", + " -3.7188e+00, 1.6953e+00],\n", + " [-6.0156e-01, 5.1172e-01, -6.1328e-01, ..., 3.4531e+00,\n", + " -2.3750e+00, -2.3633e-01]],\n", + "\n", + " [[ 1.6846e-02, 2.1973e-03, 1.4648e-02, ..., -5.1514e-02,\n", + " 3.0156e+00, 2.1094e+00],\n", + " [-3.6621e-03, 1.6602e-02, 2.6611e-02, ..., -4.6387e-02,\n", + " 3.0156e+00, 2.0938e+00],\n", + " [ 6.5918e-02, 4.3555e-01, -2.3242e-01, ..., -2.3438e-01,\n", + " 1.1016e+00, 3.1094e+00],\n", + " ...,\n", + " [-1.9531e+00, 1.0156e+00, 1.5469e+00, ..., -2.3594e+00,\n", + " -1.7031e+00, 2.6094e+00],\n", + " [ 1.4062e-01, 9.7656e-02, 1.8203e+00, ..., -1.4922e+00,\n", + " -3.1562e+00, 2.9531e+00],\n", + " [ 3.1250e-01, -2.4658e-02, 1.1719e+00, ..., 9.4531e-01,\n", + " -1.3672e+00, 2.5781e+00]],\n", + "\n", + " [[-7.2632e-03, 3.9978e-03, -3.3569e-04, ..., 4.3125e+00,\n", + " -1.1865e-01, -6.1328e-01],\n", + " [-1.1169e-02, 1.8921e-02, -6.1035e-03, ..., 4.3125e+00,\n", + " -1.2695e-01, -6.1328e-01],\n", + " [ 4.9609e-01, -7.8125e-03, -1.1035e-01, ..., 2.4219e+00,\n", + " -4.0234e-01, -4.1211e-01],\n", + " ...,\n", + " [ 1.7188e+00, -4.3945e-01, 2.6875e+00, ..., -4.1562e+00,\n", + " 1.5469e+00, -6.6406e-02],\n", + " [ 1.0000e+00, 1.5332e-01, 1.8984e+00, ..., -3.7344e+00,\n", + " 9.7656e-01, -9.0234e-01],\n", + " [-5.2344e-01, 3.2812e-01, 4.0234e-01, ..., 4.7461e-01,\n", + " -1.6641e+00, 1.9043e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[ 2.3560e-02, 7.7820e-03, -2.9907e-02, ..., -4.7119e-02,\n", + " 4.9072e-02, -1.4771e-02],\n", + " [ 4.1016e-02, 4.8218e-03, -5.5908e-02, ..., -4.2236e-02,\n", + " 2.3193e-02, -4.2114e-03],\n", + " [-1.9629e-01, 3.1641e-01, -5.5859e-01, ..., 1.4453e-01,\n", + " 9.2773e-02, -1.6211e-01],\n", + " ...,\n", + " [-2.2188e+00, 9.2578e-01, 4.7812e+00, ..., -4.9805e-01,\n", + " 6.4844e-01, -3.1641e-01],\n", + " [-8.5938e-01, -5.9570e-02, 2.7148e-01, ..., 8.8501e-03,\n", + " -8.0078e-01, 2.4707e-01],\n", + " [ 2.8320e-01, 1.1016e+00, -1.1562e+00, ..., 1.0938e+00,\n", + " -4.5508e-01, 1.7188e+00]],\n", + "\n", + " [[-9.2773e-03, 1.7578e-02, 3.6865e-02, ..., -3.5095e-03,\n", + " 1.4465e-02, 1.7456e-02],\n", + " [-1.0803e-02, 1.2360e-03, 4.4678e-02, ..., -2.1851e-02,\n", + " 1.0610e-05, 1.0620e-02],\n", + " [ 1.2207e-01, -3.8281e-01, 3.8672e-01, ..., -4.9023e-01,\n", + " 1.1572e-01, 4.5703e-01],\n", + " ...,\n", + " [ 1.1484e+00, 5.1562e+00, -3.8672e-01, ..., 1.5781e+00,\n", + " 2.9688e+00, 1.2031e+00],\n", + " [-5.9375e-01, -2.1562e+00, -5.9766e-01, ..., 5.5469e-01,\n", + " -2.4062e+00, 1.7422e+00],\n", + " [-3.1055e-01, -2.3071e-02, 1.9141e-01, ..., -5.7861e-02,\n", + " -2.3281e+00, -3.8477e-01]],\n", + "\n", + " [[ 2.8564e-02, -8.0566e-03, -1.0498e-01, ..., 4.4434e-02,\n", + " 5.0293e-02, -3.7842e-02],\n", + " [ 2.4658e-02, -1.4893e-02, -1.0791e-01, ..., 1.5076e-02,\n", + " 5.4688e-02, -2.7710e-02],\n", + " [ 3.4766e-01, 7.2327e-03, -2.8320e-01, ..., -6.3281e-01,\n", + " -1.9629e-01, -5.3516e-01],\n", + " ...,\n", + " [-3.8330e-02, 5.8203e-01, 1.9727e-01, ..., 1.3594e+00,\n", + " -8.8281e-01, -2.4062e+00],\n", + " [ 3.7500e-01, -9.3359e-01, -3.3398e-01, ..., 1.0781e+00,\n", + " -1.1094e+00, -3.8086e-01],\n", + " [-9.2578e-01, 7.1875e-01, -5.8594e-01, ..., -9.9121e-02,\n", + " -2.7344e-01, 1.3828e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 3.5248e-03, 3.8757e-03, 1.7334e-02, ..., -1.7383e-01,\n", + " 3.6914e-01, -5.4375e+00],\n", + " [-2.6855e-02, -3.2043e-03, 2.3682e-02, ..., -1.6602e-01,\n", + " 3.6719e-01, -5.4375e+00],\n", + " [ 5.5469e-01, 6.9922e-01, -2.2559e-01, ..., -8.5156e-01,\n", + " -1.7383e-01, -3.4375e+00],\n", + " ...,\n", + " [ 2.2500e+00, -1.0234e+00, 1.2422e+00, ..., 2.6758e-01,\n", + " -6.5234e-01, -4.4336e-01],\n", + " [ 1.6484e+00, 3.3984e-01, 1.6016e-01, ..., -8.0469e-01,\n", + " -1.3047e+00, 3.6328e-01],\n", + " [-8.7891e-01, -9.5215e-03, -5.5469e-01, ..., 9.8047e-01,\n", + " -4.1406e-01, -1.6562e+00]],\n", + "\n", + " [[-5.2185e-03, 5.3406e-04, -2.5558e-04, ..., -2.5879e-02,\n", + " 4.2500e+00, -2.4512e-01],\n", + " [ 6.4697e-03, 3.0518e-04, -4.4556e-03, ..., -3.8818e-02,\n", + " 4.2500e+00, -2.5586e-01],\n", + " [-1.9336e-01, -1.4844e-01, -2.7539e-01, ..., -9.7656e-02,\n", + " 2.2656e+00, 2.7344e-01],\n", + " ...,\n", + " [ 1.6875e+00, 2.5312e+00, -2.0020e-01, ..., -1.0391e+00,\n", + " -2.6406e+00, 2.9375e+00],\n", + " [-4.3555e-01, -4.3750e-01, 2.3340e-01, ..., -9.9609e-01,\n", + " -8.6328e-01, 1.1250e+00],\n", + " [-1.0234e+00, -1.1172e+00, 3.5547e-01, ..., -2.0000e+00,\n", + " 1.8047e+00, -3.5706e-03]],\n", + "\n", + " [[-1.1414e-02, -6.1951e-03, 3.0279e-05, ..., 4.5000e+00,\n", + " -7.4609e-01, -1.7969e+00],\n", + " [-1.3550e-02, -3.9673e-03, -4.9744e-03, ..., 4.4688e+00,\n", + " -7.5000e-01, -1.7969e+00],\n", + " [ 8.8281e-01, -5.7812e-01, 7.2266e-01, ..., 1.0781e+00,\n", + " -1.3281e+00, -1.9609e+00],\n", + " ...,\n", + " [ 2.2500e+00, 2.6758e-01, -1.2500e+00, ..., -4.9062e+00,\n", + " -1.3125e+00, -1.2422e+00],\n", + " [ 2.2344e+00, -7.5000e-01, -1.4609e+00, ..., -4.2500e+00,\n", + " -2.6250e+00, -1.5469e+00],\n", + " [ 2.9297e-03, -1.5625e+00, -6.5625e-01, ..., -1.4141e+00,\n", + " -1.8281e+00, -2.0625e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[ 8.3984e-02, -1.2512e-02, -3.1982e-02, ..., -2.2168e-01,\n", + " -4.2114e-03, -9.4727e-02],\n", + " [ 8.9355e-02, -1.0315e-02, -3.2471e-02, ..., -2.3047e-01,\n", + " -2.7618e-03, -9.6680e-02],\n", + " [ 2.9297e-01, -1.9531e-01, 8.9645e-04, ..., -3.4766e-01,\n", + " -2.0410e-01, -1.7188e-01],\n", + " ...,\n", + " [-1.7109e+00, -6.3281e-01, 4.1602e-01, ..., 1.0703e+00,\n", + " -4.4922e-01, 4.3945e-01],\n", + " [-1.1016e+00, 3.6719e-01, 1.0078e+00, ..., -1.6699e-01,\n", + " -1.9297e+00, -6.0547e-01],\n", + " [-2.3926e-01, 6.3281e-01, 1.4941e-01, ..., 8.6328e-01,\n", + " -1.0859e+00, -1.1562e+00]],\n", + "\n", + " [[ 8.8379e-02, -2.2266e-01, 8.9844e-02, ..., 1.2451e-01,\n", + " -4.0234e-01, -7.1106e-03],\n", + " [ 9.1797e-02, -2.3438e-01, 9.5703e-02, ..., 1.1475e-01,\n", + " -4.0234e-01, -1.0498e-02],\n", + " [-1.3281e-01, -7.2266e-01, -8.0078e-02, ..., -1.4648e-01,\n", + " -7.1875e-01, 3.0273e-01],\n", + " ...,\n", + " [ 1.3516e+00, 2.1406e+00, -4.1797e-01, ..., 1.2031e+00,\n", + " 1.9922e-01, 7.8516e-01],\n", + " [ 1.3281e+00, 6.4062e-01, -2.0156e+00, ..., 1.5391e+00,\n", + " -1.3750e+00, 8.9355e-02],\n", + " [-2.3633e-01, -5.0000e-01, -1.2500e+00, ..., -4.4727e-01,\n", + " -2.3750e+00, 5.8594e-01]],\n", + "\n", + " [[-1.5234e-01, -1.2598e-01, -5.8594e-02, ..., -9.7656e-02,\n", + " -1.5820e-01, 1.4746e-01],\n", + " [-1.5723e-01, -1.1865e-01, -4.9316e-02, ..., -1.0254e-01,\n", + " -1.6309e-01, 1.6113e-01],\n", + " [-3.7842e-02, -1.5503e-02, 3.1641e-01, ..., 2.0142e-02,\n", + " -2.3730e-01, 1.4648e-01],\n", + " ...,\n", + " [ 7.6953e-01, -1.0156e+00, 1.8047e+00, ..., 2.3926e-01,\n", + " 1.3359e+00, -9.1016e-01],\n", + " [ 7.4219e-01, -9.3750e-01, -2.1094e-01, ..., -2.1562e+00,\n", + " -4.1748e-02, 2.6875e+00],\n", + " [-7.6953e-01, -2.4902e-01, -1.3047e+00, ..., -1.0781e+00,\n", + " 8.0078e-01, 1.7812e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[-7.9102e-02, -7.2021e-03, -5.1758e-02, ..., -2.5469e+00,\n", + " -2.5938e+00, -1.4688e+00],\n", + " [-1.4453e-01, -1.0071e-02, -1.0596e-01, ..., -2.5625e+00,\n", + " -2.6094e+00, -1.4688e+00],\n", + " [-2.3438e-01, 1.3867e-01, -7.4707e-02, ..., -1.1172e+00,\n", + " -3.4688e+00, -1.2969e+00],\n", + " ...,\n", + " [ 2.5977e-01, 2.6562e-01, 3.8867e-01, ..., 7.5781e-01,\n", + " -3.6719e+00, -2.5625e+00],\n", + " [-7.0312e-01, 6.2500e-01, -2.0312e-01, ..., 9.6875e-01,\n", + " -4.7188e+00, -3.0469e+00],\n", + " [ 3.3691e-02, 9.1797e-02, 5.3223e-02, ..., -5.3906e-01,\n", + " -3.0000e+00, -4.2969e-01]],\n", + "\n", + " [[-1.6724e-02, -1.3123e-02, -1.1414e-02, ..., -5.9375e-01,\n", + " 1.7480e-01, -3.3398e-01],\n", + " [-6.7139e-03, -4.0894e-03, 9.8877e-03, ..., -5.8594e-01,\n", + " 1.6309e-01, -3.4766e-01],\n", + " [-3.4766e-01, -2.8125e-01, 4.1406e-01, ..., -6.5625e-01,\n", + " -8.7402e-02, -1.4062e+00],\n", + " ...,\n", + " [ 1.7500e+00, -8.0469e-01, -3.3438e+00, ..., -1.4219e+00,\n", + " 5.7031e-01, -7.8125e-01],\n", + " [-1.0391e+00, -7.8125e-01, -2.7969e+00, ..., -1.1953e+00,\n", + " 1.0938e+00, -1.6484e+00],\n", + " [-7.0312e-01, 1.5430e-01, -1.2500e+00, ..., -1.3438e+00,\n", + " 1.0938e+00, -3.6875e+00]],\n", + "\n", + " [[-2.0264e-02, 6.1646e-03, 1.2436e-03, ..., 1.1016e+00,\n", + " -3.9062e-01, 1.9766e+00],\n", + " [ 7.1411e-03, -4.5166e-03, -1.4893e-02, ..., 1.0938e+00,\n", + " -3.9062e-01, 1.9688e+00],\n", + " [ 5.1562e-01, 5.3906e-01, -4.8828e-03, ..., 1.1016e+00,\n", + " -1.9824e-01, 7.1484e-01],\n", + " ...,\n", + " [ 1.6797e+00, 1.0400e-01, 2.4414e-01, ..., 1.4160e-01,\n", + " 2.5781e-01, -2.5312e+00],\n", + " [ 9.0625e-01, 2.2656e-01, 1.8203e+00, ..., -2.5513e-02,\n", + " -3.4180e-01, -2.4688e+00],\n", + " [-1.7871e-01, 1.2598e-01, 9.1016e-01, ..., -1.4531e+00,\n", + " -2.9297e-01, 1.9688e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[-1.1292e-02, 5.7373e-02, -6.1768e-02, ..., -1.7334e-02,\n", + " -8.1177e-03, -1.4648e-02],\n", + " [-1.8799e-02, 4.7607e-02, -7.2754e-02, ..., -8.1177e-03,\n", + " -7.2632e-03, 7.1716e-03],\n", + " [-1.4551e-01, -5.1953e-01, -2.1582e-01, ..., 3.0273e-01,\n", + " 6.8848e-02, 4.0430e-01],\n", + " ...,\n", + " [-3.8281e-01, 7.5781e-01, 1.1094e+00, ..., 1.6797e+00,\n", + " 3.9453e-01, -1.9141e+00],\n", + " [-7.1484e-01, -3.8086e-01, -1.9297e+00, ..., 2.7969e+00,\n", + " -1.6719e+00, -6.9531e-01],\n", + " [-9.5703e-01, 4.5898e-01, -8.9453e-01, ..., 2.5781e+00,\n", + " 7.3047e-01, 1.0625e+00]],\n", + "\n", + " [[ 7.5195e-02, 9.9121e-02, -2.1973e-03, ..., 2.9053e-02,\n", + " 2.6123e-02, 3.2227e-02],\n", + " [ 7.1777e-02, 9.6680e-02, -9.8267e-03, ..., 1.1536e-02,\n", + " 3.2471e-02, 4.5654e-02],\n", + " [-6.9824e-02, -7.2754e-02, -2.3340e-01, ..., -4.1260e-02,\n", + " -1.4062e-01, -3.1836e-01],\n", + " ...,\n", + " [-1.9375e+00, 5.4688e-01, -1.7480e-01, ..., 3.9375e+00,\n", + " -4.6680e-01, -2.9844e+00],\n", + " [-1.2812e+00, 1.9688e+00, 6.4453e-01, ..., 2.0781e+00,\n", + " -2.5156e+00, -2.4805e-01],\n", + " [ 6.0547e-01, 2.5781e-01, 2.3281e+00, ..., -9.0234e-01,\n", + " -1.9141e+00, -1.8359e+00]],\n", + "\n", + " [[-1.1475e-01, -3.7305e-01, -7.3242e-02, ..., 4.4727e-01,\n", + " 1.1426e-01, -1.0254e-01],\n", + " [-1.1182e-01, -3.9453e-01, -7.5195e-02, ..., 4.4531e-01,\n", + " 1.0986e-01, -1.1426e-01],\n", + " [-7.9102e-02, -8.5547e-01, 2.7539e-01, ..., 1.0078e+00,\n", + " 9.8828e-01, 3.1055e-01],\n", + " ...,\n", + " [ 6.3281e-01, 6.2500e-01, -2.5000e-01, ..., -6.0303e-02,\n", + " 1.5703e+00, 2.4688e+00],\n", + " [ 1.3906e+00, -4.0234e-01, 6.2891e-01, ..., 1.6328e+00,\n", + " -2.8281e+00, 3.7031e+00],\n", + " [ 6.2500e-01, 8.0469e-01, 9.0625e-01, ..., 8.4375e-01,\n", + " -5.8984e-01, -5.7861e-02]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=)), (tensor([[[[ 1.0803e-02, -5.3024e-04, -2.5513e-02, ..., 2.7031e+00,\n", + " 3.7031e+00, 2.3906e+00],\n", + " [-2.6367e-02, 3.5156e-02, -6.0059e-02, ..., 2.7500e+00,\n", + " 3.6562e+00, 2.3750e+00],\n", + " [-1.0781e+00, -2.4414e-04, -4.8047e-01, ..., -6.1328e-01,\n", + " -5.1250e+00, -4.1602e-01],\n", + " ...,\n", + " [-1.5781e+00, 5.0391e-01, 3.2500e+00, ..., 2.5469e+00,\n", + " -3.6094e+00, 2.5156e+00],\n", + " [-1.3672e+00, -8.1250e-01, 1.9453e+00, ..., 3.9688e+00,\n", + " -2.3125e+00, 1.5625e+00],\n", + " [-1.7383e-01, -6.1328e-01, 3.3594e-01, ..., 3.2656e+00,\n", + " -5.6641e-01, 2.5000e+00]],\n", + "\n", + " [[ 1.8066e-02, 8.1787e-03, -2.2217e-02, ..., 5.0049e-02,\n", + " 4.0000e+00, 2.5000e+00],\n", + " [ 5.9814e-03, 2.7161e-03, -1.6479e-02, ..., 4.2725e-02,\n", + " 4.0938e+00, 2.5781e+00],\n", + " [ 1.0000e+00, -2.8125e-01, -1.1562e+00, ..., 8.5449e-02,\n", + " -2.6094e+00, -7.6660e-02],\n", + " ...,\n", + " [ 1.7188e-01, -2.5469e+00, -8.4766e-01, ..., -1.5918e-01,\n", + " -1.3359e+00, 1.7090e-01],\n", + " [ 2.5938e+00, -1.2812e+00, 9.5703e-01, ..., 3.7500e-01,\n", + " -1.9141e+00, 1.1797e+00],\n", + " [ 1.5234e+00, 5.8594e-01, 1.1406e+00, ..., -4.9805e-01,\n", + " 3.1875e+00, 2.6875e+00]],\n", + "\n", + " [[-2.0874e-02, 1.2390e-02, 4.8584e-02, ..., -3.3750e+00,\n", + " -1.4531e+00, -4.7500e+00],\n", + " [ 1.6235e-02, -5.6396e-02, 2.2095e-02, ..., -3.4688e+00,\n", + " -1.4609e+00, -4.8125e+00],\n", + " [-3.5938e-01, 3.6133e-01, -1.7480e-01, ..., 2.4062e+00,\n", + " 5.9766e-01, 3.8281e+00],\n", + " ...,\n", + " [-2.9688e-01, 1.7969e+00, 6.2891e-01, ..., 2.9375e+00,\n", + " -2.2031e+00, 1.3733e-04],\n", + " [-5.7422e-01, 6.2500e-01, -6.4062e-01, ..., 1.7656e+00,\n", + " -1.1094e+00, 6.2500e-01],\n", + " [-1.3086e-01, 3.3984e-01, 1.7969e-01, ..., -4.9219e-01,\n", + " -4.7266e-01, -5.4297e-01]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[[ 1.4258e-01, -1.8945e-01, 3.5938e-01, ..., 3.1982e-02,\n", + " -2.8198e-02, 4.8340e-02],\n", + " [ 1.3867e-01, -2.0801e-01, 3.5938e-01, ..., 7.7148e-02,\n", + " -1.7456e-02, 6.8848e-02],\n", + " [-7.2656e-01, 3.1641e-01, 1.5332e-01, ..., 4.1406e-01,\n", + " 5.5078e-01, -1.2256e-01],\n", + " ...,\n", + " [-6.8750e+00, -5.0000e+00, 3.8906e+00, ..., -1.4219e+00,\n", + " -3.6719e+00, -2.7344e+00],\n", + " [-1.9922e+00, -1.1406e+00, 6.9141e-01, ..., -2.6406e+00,\n", + " -3.2344e+00, -5.8984e-01],\n", + " [-2.5469e+00, 5.5859e-01, 2.8906e+00, ..., -1.6094e+00,\n", + " -3.1719e+00, 1.3965e-01]],\n", + "\n", + " [[-8.3496e-02, -4.1748e-02, 1.5234e-01, ..., 8.5449e-03,\n", + " -3.6914e-01, 2.5000e-01],\n", + " [-5.3223e-02, -9.0332e-03, 1.9434e-01, ..., 4.4632e-04,\n", + " -4.6094e-01, 2.9102e-01],\n", + " [-2.8125e-01, 5.4688e-01, 3.6328e-01, ..., -4.3701e-02,\n", + " -8.7500e-01, 1.1426e-01],\n", + " ...,\n", + " [ 1.0400e-01, -3.7842e-02, -1.1328e+00, ..., 1.6992e-01,\n", + " 1.0391e+00, 1.0391e+00],\n", + " [ 4.0039e-01, 1.8359e+00, 6.6016e-01, ..., -1.3125e+00,\n", + " -1.6797e+00, 9.6094e-01],\n", + " [-2.0156e+00, 2.6094e+00, -7.1094e-01, ..., -4.2188e-01,\n", + " 1.6406e-01, 6.5234e-01]],\n", + "\n", + " [[ 1.2305e-01, 2.9541e-02, -6.3965e-02, ..., 1.8359e-01,\n", + " -6.7444e-03, -1.0596e-01],\n", + " [ 1.7676e-01, 2.6611e-02, -1.1084e-01, ..., 2.0508e-01,\n", + " 6.7444e-03, -1.5039e-01],\n", + " [ 7.3828e-01, 3.6719e-01, -3.2812e-01, ..., -1.9897e-02,\n", + " 2.5586e-01, -1.5430e-01],\n", + " ...,\n", + " [-9.8828e-01, 3.3750e+00, 1.0859e+00, ..., -2.7344e-01,\n", + " 2.8906e+00, 3.4375e+00],\n", + " [-6.9336e-02, -5.7422e-01, 6.0938e-01, ..., 4.3438e+00,\n", + " 4.1797e-01, -3.0859e-01],\n", + " [-3.1738e-02, 1.3359e+00, 7.6172e-01, ..., 6.2891e-01,\n", + " 5.8838e-02, -1.5312e+00]]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=))), hidden_states=(tensor([[[ 0.0258, 0.0635, 0.0488, ..., 0.0162, -0.0068, -0.0124],\n", + " [ 0.1172, 0.0859, 0.1621, ..., 0.0410, 0.0569, 0.0532],\n", + " [ 0.0234, 0.0928, 0.1011, ..., 0.0069, -0.1455, 0.1001],\n", + " ...,\n", + " [ 0.3359, -0.1104, 0.0200, ..., -0.0272, 0.1738, 0.0165],\n", + " [ 0.0515, -0.0237, 0.0374, ..., 0.1221, -0.0471, -0.0310],\n", + " [-0.0815, -0.0664, 0.1406, ..., 0.0496, 0.0077, 0.1201]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-1.7344, 0.9609, -1.4922, ..., 0.8047, 1.2891, -0.3438],\n", + " [ 4.5312, 0.5312, 0.0625, ..., -0.5703, 0.0098, 0.8203],\n", + " [-3.1406, 0.0840, 1.8750, ..., 0.9141, -0.2754, 1.4766],\n", + " ...,\n", + " [ 0.3086, 0.1465, 0.9375, ..., 0.1523, 0.2617, -0.6797],\n", + " [-4.5625, -0.5664, 1.2109, ..., 0.3301, 1.9609, -0.4922],\n", + " [ 2.5156, -0.6602, 0.8945, ..., 0.4258, -0.0349, 0.2500]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-1.1328, 1.6953, 0.0078, ..., 2.1875, -1.6172, 0.8281],\n", + " [ 5.6250, 1.5547, 1.0625, ..., -0.7422, 0.0747, -0.0566],\n", + " [-3.9375, -1.0625, 1.7734, ..., 1.7734, -5.2812, 0.1846],\n", + " ...,\n", + " [ 3.0938, 1.9844, 2.7812, ..., 2.0625, 0.0430, -0.5469],\n", + " [-5.2812, -1.6094, 3.1406, ..., -0.0078, 1.5156, -1.4062],\n", + " [ 3.3438, 0.0625, 1.8750, ..., 0.4883, 0.1172, 0.2734]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-3.2188, 1.5625, 0.6172, ..., -0.3594, -2.6406, -0.0078],\n", + " [ 2.9375, 1.0781, -0.3438, ..., -3.5625, -1.6250, -0.9453],\n", + " [-6.6875, -0.3828, 1.7500, ..., -0.2109, -3.6406, -1.1094],\n", + " ...,\n", + " [ 2.8750, 0.5312, 3.2031, ..., 3.0000, 0.5898, -0.8203],\n", + " [-4.3750, -2.5312, 3.2031, ..., 0.6250, 2.1562, -3.0938],\n", + " [ 4.4062, 0.1226, 2.1719, ..., -0.7266, -0.5469, -0.3027]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-3.4688, 1.2188, 0.5547, ..., -0.2227, -3.0781, -0.3047],\n", + " [ 2.6250, 0.7969, -0.4453, ..., -3.4375, -2.2031, -1.2969],\n", + " [-6.9062, -0.5352, 1.5938, ..., -0.0483, -4.0938, -1.1484],\n", + " ...,\n", + " [ 5.5000, 1.0703, 2.6562, ..., 2.6875, 0.9844, 0.9375],\n", + " [-3.0938, -3.6719, 1.5312, ..., -0.4453, 0.1484, -1.7734],\n", + " [ 5.2500, 0.3555, 2.4219, ..., 0.1777, 0.2969, -0.7969]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-3.5781, 1.2109, 0.3809, ..., -0.1230, -3.5938, -0.0254],\n", + " [ 2.4219, 0.7656, -0.6172, ..., -3.3281, -2.4219, -1.2656],\n", + " [-7.1250, -0.6250, 1.4062, ..., -0.0264, -4.2188, -0.9727],\n", + " ...,\n", + " [ 4.6562, 2.7188, 3.5000, ..., 2.1875, 0.2266, 1.3750],\n", + " [-3.6406, -3.1250, 3.3906, ..., 1.4141, 1.3438, -0.7812],\n", + " [ 5.5938, -0.7539, 1.5469, ..., 0.8203, -0.1133, -1.3438]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-3.5312, 2.0000, 0.1611, ..., 0.2324, -3.3594, -0.1660],\n", + " [ 2.5938, 1.6562, -0.6875, ..., -2.8750, -2.0938, -1.4062],\n", + " [-7.1562, 0.3340, 1.3047, ..., 0.1367, -3.7969, -1.0000],\n", + " ...,\n", + " [ 3.9219, 2.0781, 1.4531, ..., 2.9062, 0.9531, 2.0469],\n", + " [-5.4375, -1.1250, 1.9375, ..., 0.6797, 0.2109, -1.4375],\n", + " [ 4.4062, -0.4609, 2.7656, ..., -0.5078, -2.2344, 0.7734]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-3.6875, 2.0469, 0.0186, ..., 0.3828, -3.2656, -0.2852],\n", + " [ 2.3594, 1.7109, -0.7656, ..., -2.5469, -1.9062, -1.6016],\n", + " [-7.3750, 0.3633, 1.1328, ..., 0.2334, -3.7188, -1.1953],\n", + " ...,\n", + " [ 2.9844, 0.9922, 2.8438, ..., 1.7109, 0.0781, 1.7266],\n", + " [-4.9062, -2.2812, 1.5078, ..., -0.1035, -0.9375, -0.8047],\n", + " [ 4.8750, -1.2500, 3.7031, ..., -0.7930, -0.2227, 0.8516]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-3.8438, 2.6719, -0.1641, ..., 0.5859, -3.4062, -0.4961],\n", + " [ 2.1719, 2.3438, -1.0625, ..., -2.4844, -2.0781, -1.7891],\n", + " [-7.4688, 0.8672, 0.9453, ..., 0.2275, -3.8438, -1.2500],\n", + " ...,\n", + " [ 3.0625, 1.5234, 3.0625, ..., 3.3594, -0.5195, 3.1562],\n", + " [-4.7500, -0.8438, 1.6328, ..., 1.2891, -2.3125, -0.4766],\n", + " [ 4.6250, -1.6953, 2.7031, ..., -1.3047, 0.8242, 1.4062]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-4.1875, 2.2812, 1.1016, ..., 2.0625, -4.0312, -1.3750],\n", + " [ 2.0156, 1.6562, 0.4141, ..., -1.4297, -2.2344, -3.0938],\n", + " [-7.7812, 0.2656, 2.4375, ..., 1.5234, -4.3125, -2.2344],\n", + " ...,\n", + " [ 1.1406, 2.3438, 1.5000, ..., 1.3750, -1.5781, 1.3984],\n", + " [-4.4375, -0.2461, -0.6406, ..., 0.2969, -4.0312, 0.2500],\n", + " [ 3.0938, 0.5703, 1.4141, ..., -0.2344, 1.0156, 0.6055]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-1.0312e+00, 2.0625e+00, 1.0938e-01, ..., -3.9062e-01,\n", + " -2.1875e-01, -8.9062e-01],\n", + " [ 4.9375e+00, 1.4531e+00, -9.7656e-02, ..., -3.4219e+00,\n", + " 1.2656e+00, -2.2969e+00],\n", + " [-4.5000e+00, 2.8516e-01, 1.5234e+00, ..., -1.0938e+00,\n", + " -7.1875e-01, -1.7422e+00],\n", + " ...,\n", + " [ 1.3281e-01, 4.5938e+00, 2.3438e+00, ..., 3.6406e+00,\n", + " -6.1562e+00, 6.4453e-01],\n", + " [-2.6406e+00, 5.7031e-01, -1.9531e-03, ..., -2.1094e-01,\n", + " -4.7188e+00, -1.3750e+00],\n", + " [ 3.6250e+00, 1.0547e+00, 2.0469e+00, ..., -1.1719e+00,\n", + " -3.2422e-01, 9.1797e-01]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[-1.4453, 3.9531, -2.2500, ..., -4.0312, 2.1094, 1.0312],\n", + " [ 6.5625, 3.4219, -3.7500, ..., -7.1250, 2.2969, 0.4688],\n", + " [-3.0625, 1.5078, -1.5156, ..., -2.3750, 2.9062, 1.6875],\n", + " ...,\n", + " [-1.0547, 4.7500, 3.9375, ..., 1.0000, -3.9531, 2.2812],\n", + " [-4.1250, 1.8828, 0.4375, ..., -2.0781, -4.4688, -1.0312],\n", + " [ 3.8906, 1.1250, 1.0234, ..., -0.7812, -0.4688, 0.3945]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-20.6250, 10.5625, 10.3750, ..., 30.1250, -9.8750, -10.3125],\n", + " [-10.1250, 10.5625, 9.6875, ..., 26.7500, -8.6875, -8.1250],\n", + " [ -7.5000, -7.7500, -8.8750, ..., -7.8125, 31.7500, -8.8750],\n", + " ...,\n", + " [ 1.0078, 2.0938, 4.0312, ..., 1.7578, -1.5469, 2.0938],\n", + " [ -2.7188, 0.1406, 1.0547, ..., -0.6328, 0.1875, -1.0000],\n", + " [ 2.8438, 0.7188, 0.9531, ..., 1.3750, 0.4277, -0.8008]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-20.8750, 10.8750, 10.0625, ..., 30.5000, -10.5000, -11.1875],\n", + " [-10.3750, 10.8750, 9.3750, ..., 27.1250, -9.3125, -9.0000],\n", + " [ -7.7500, -7.8125, -9.2500, ..., -7.3750, 30.8750, -9.5625],\n", + " ...,\n", + " [ 0.0859, 3.1719, 3.2188, ..., 0.4453, 0.5078, -0.2734],\n", + " [ -2.3438, 0.7812, 1.1250, ..., 0.1484, 0.4688, -2.6250],\n", + " [ 2.8281, 0.2891, 1.2344, ..., 0.8125, 0.4219, -0.1387]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-2.0875e+01, 1.1250e+01, 9.8125e+00, ..., 3.0375e+01,\n", + " -1.1500e+01, -1.1312e+01],\n", + " [-1.0250e+01, 1.1250e+01, 9.1250e+00, ..., 2.7000e+01,\n", + " -1.0312e+01, -9.1250e+00],\n", + " [-7.7812e+00, -8.0000e+00, -9.5000e+00, ..., -7.1562e+00,\n", + " 2.9875e+01, -9.5000e+00],\n", + " ...,\n", + " [ 1.5156e+00, 2.3125e+00, 4.5312e+00, ..., 8.7109e-01,\n", + " -1.3281e+00, 3.6328e-01],\n", + " [-5.4375e+00, -1.1094e+00, 3.0625e+00, ..., 3.4062e+00,\n", + " 7.8125e-03, -9.3750e-01],\n", + " [ 2.7812e+00, 6.9141e-01, 8.1641e-01, ..., -3.6719e-01,\n", + " -1.2656e+00, -3.5938e-01]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[-2.1000e+01, 1.1062e+01, 9.5625e+00, ..., 3.0125e+01,\n", + " -1.1562e+01, -1.0875e+01],\n", + " [-1.0375e+01, 1.1062e+01, 8.8750e+00, ..., 2.6750e+01,\n", + " -1.0375e+01, -8.6875e+00],\n", + " [-7.8750e+00, -8.1875e+00, -9.9375e+00, ..., -7.1875e+00,\n", + " 2.9375e+01, -9.2500e+00],\n", + " ...,\n", + " [ 1.5312e+00, 1.5156e+00, 2.4219e+00, ..., 2.5312e+00,\n", + " 4.0625e-01, -5.3125e-01],\n", + " [-5.2188e+00, -1.4375e+00, -1.0938e+00, ..., 2.5312e+00,\n", + " -3.3594e-01, -2.4688e+00],\n", + " [ 5.4688e+00, -1.5625e-02, 1.5703e+00, ..., 2.0625e+00,\n", + " 1.3906e+00, -4.8828e-01]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[-2.1375e+01, 1.0312e+01, 9.4375e+00, ..., 2.9375e+01,\n", + " -1.1438e+01, -1.0250e+01],\n", + " [-1.0750e+01, 1.0312e+01, 8.7500e+00, ..., 2.6000e+01,\n", + " -1.0250e+01, -8.0625e+00],\n", + " [-8.3750e+00, -9.0000e+00, -1.0125e+01, ..., -7.7500e+00,\n", + " 2.9500e+01, -8.6250e+00],\n", + " ...,\n", + " [ 3.5938e+00, 1.1328e+00, 1.9219e+00, ..., 7.9688e+00,\n", + " -3.3203e-01, 1.9531e-02],\n", + " [-3.0312e+00, -1.2031e+00, -2.7969e+00, ..., 5.4375e+00,\n", + " 2.0312e-01, -9.7656e-02],\n", + " [ 7.6250e+00, -1.6641e+00, 6.7969e-01, ..., 6.0000e+00,\n", + " -1.6172e+00, 5.2344e-01]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[-21.2500, 10.6875, 9.2500, ..., 28.3750, -12.0000, -10.7500],\n", + " [-10.5625, 10.6875, 8.5625, ..., 25.0000, -10.7500, -8.5625],\n", + " [ -8.1250, -8.6250, -10.1875, ..., -8.6250, 28.8750, -9.3750],\n", + " ...,\n", + " [ 3.8438, 0.9219, 3.3906, ..., 3.2812, 1.1250, -5.9062],\n", + " [ -4.0312, -2.4062, -1.9688, ..., 5.2812, -1.0078, 0.4180],\n", + " [ 7.5000, -3.2656, -0.1504, ..., 4.5938, 0.4062, 1.1562]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-21.1250, 10.5000, 8.8125, ..., 28.0000, -12.3125, -10.8750],\n", + " [-10.4375, 10.5000, 8.1250, ..., 24.6250, -11.0625, -8.6875],\n", + " [ -7.9375, -8.8750, -10.8125, ..., -8.6875, 28.6250, -9.5625],\n", + " ...,\n", + " [ 4.3125, 2.2188, 2.3125, ..., 3.2500, -0.6875, -2.6875],\n", + " [ -4.7812, -1.5078, -2.5938, ..., 4.5938, -3.0625, 1.6250],\n", + " [ 6.9375, -1.2031, -0.8125, ..., 6.0625, -3.5781, 3.9219]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-21.0000, 9.7500, 8.6250, ..., 27.8750, -13.1250, -11.4375],\n", + " [-10.2500, 9.8125, 7.9375, ..., 24.5000, -11.8750, -9.1875],\n", + " [ -7.9062, -9.5000, -10.7500, ..., -9.0625, 28.0000, -9.6250],\n", + " ...,\n", + " [ 6.2500, 1.1719, 0.2422, ..., 1.2578, -1.9297, -4.7500],\n", + " [ -4.1250, -2.5938, -3.6406, ..., 3.3438, -5.0312, -1.4688],\n", + " [ 4.7500, -0.7734, -0.9844, ..., 6.1562, -2.2188, 1.9766]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-21.3750, 8.8125, 9.0000, ..., 26.6250, -14.4375, -12.4375],\n", + " [-10.6250, 8.8750, 8.3125, ..., 23.2500, -13.1875, -10.1875],\n", + " [ -8.1875, -10.5000, -10.3125, ..., -10.1250, 26.8750, -10.4375],\n", + " ...,\n", + " [ 3.7969, 1.7344, -0.1719, ..., -2.0312, -1.4375, -2.5156],\n", + " [ -4.7500, -3.5781, -5.0938, ..., -1.3594, -4.1875, -0.3945],\n", + " [ 6.3125, 0.9023, 0.9961, ..., 3.0156, -0.2969, 1.7812]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-2.1375e+01, 8.0625e+00, 8.7500e+00, ..., 2.6875e+01,\n", + " -1.4062e+01, -1.2375e+01],\n", + " [-1.0688e+01, 8.1250e+00, 8.0625e+00, ..., 2.3500e+01,\n", + " -1.2750e+01, -1.0125e+01],\n", + " [-8.1875e+00, -1.1250e+01, -1.0312e+01, ..., -9.4375e+00,\n", + " 2.7750e+01, -1.0938e+01],\n", + " ...,\n", + " [ 6.8125e+00, 8.5938e-02, -7.3438e-01, ..., 2.7500e+00,\n", + " 7.5000e-01, -3.2500e+00],\n", + " [-5.3750e+00, -2.3750e+00, -5.0312e+00, ..., 4.8125e+00,\n", + " -4.9375e+00, -1.8359e-01],\n", + " [ 1.0938e+01, -1.5469e+00, 3.9062e-03, ..., 6.3750e+00,\n", + " -7.3438e-01, 9.0625e-01]]], device='cuda:0', dtype=torch.bfloat16,\n", + " grad_fn=), tensor([[[-21.5000, 7.8125, 8.0000, ..., 26.3750, -15.0000, -11.9375],\n", + " [-10.8750, 7.8750, 7.3438, ..., 23.0000, -13.6875, -9.6875],\n", + " [ -8.4375, -11.4375, -10.8125, ..., -9.8750, 26.7500, -10.5625],\n", + " ...,\n", + " [ 7.5312, -1.3750, 2.9062, ..., -2.4375, 1.3984, -1.4531],\n", + " [ -6.3125, -2.6094, -5.5625, ..., 2.2344, -7.4375, -3.8281],\n", + " [ 8.8750, 3.4375, -1.9844, ..., 5.8750, 3.0625, 2.6250]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-21.6250, 8.8125, 8.4375, ..., 26.7500, -15.1250, -12.1875],\n", + " [-11.0000, 8.8750, 7.7812, ..., 23.3750, -13.8125, -9.9375],\n", + " [ -8.6875, -10.5625, -10.6250, ..., -9.3750, 26.1250, -10.9375],\n", + " ...,\n", + " [ 9.3750, -1.2500, 1.4766, ..., -3.7812, 0.4297, 1.3281],\n", + " [ -5.8125, -5.5625, -4.1250, ..., 1.9844, -10.8750, -2.8906],\n", + " [ 9.0000, -0.3906, -3.1250, ..., 9.2500, 2.0625, 0.2812]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-22.5000, 9.7500, 7.4688, ..., 26.2500, -14.2500, -11.3750],\n", + " [-11.8750, 9.8750, 6.8438, ..., 22.8750, -13.0000, -9.1250],\n", + " [ -9.5000, -10.2500, -11.2500, ..., -8.5000, 26.1250, -9.8125],\n", + " ...,\n", + " [ 19.1250, -12.3125, 4.1875, ..., -18.0000, 8.6250, 9.3750],\n", + " [-17.5000, -16.7500, -8.3750, ..., -5.8750, -16.5000, -4.8438],\n", + " [ 10.0625, -10.6250, -0.3672, ..., 9.3750, -3.2969, -2.8906]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-21.6250, 9.8750, 5.8125, ..., 26.0000, -13.6875, -8.5625],\n", + " [-10.8750, 10.0000, 5.1250, ..., 22.6250, -12.3750, -6.2812],\n", + " [ -8.5000, -10.6875, -12.8750, ..., -9.0000, 26.8750, -7.3125],\n", + " ...,\n", + " [ 25.2500, -6.0625, 4.0938, ..., -23.8750, -0.1250, 22.2500],\n", + " [ -9.8750, -25.7500, -5.3125, ..., -4.9375, 0.0625, 2.0625],\n", + " [ -1.0000, -13.8125, 3.1875, ..., 7.5312, -2.8438, 2.3438]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-22.6250, 9.6250, 6.3125, ..., 25.5000, -13.5625, -8.8125],\n", + " [-11.7500, 9.7500, 5.6250, ..., 22.1250, -12.2500, -6.5938],\n", + " [ -9.2500, -10.6875, -12.6875, ..., -9.5000, 27.1250, -7.3750],\n", + " ...,\n", + " [ 7.6250, -20.3750, -6.6250, ..., -28.7500, 5.9375, 26.6250],\n", + " [-26.2500, -25.1250, -12.0000, ..., -12.5000, 11.4375, 3.6875],\n", + " [ -7.7812, -14.8750, 2.1406, ..., 6.8438, 2.6562, 5.1250]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-22.5000, 9.5625, 5.8125, ..., 25.6250, -13.7500, -8.2500],\n", + " [-11.5000, 9.6875, 5.1250, ..., 22.2500, -12.3750, -6.0000],\n", + " [ -9.1250, -11.0000, -13.3125, ..., -9.1875, 26.7500, -7.1250],\n", + " ...,\n", + " [ -4.4688, -30.7500, -6.0625, ..., -31.7500, 1.1562, 20.6250],\n", + " [-34.2500, -40.5000, -19.1250, ..., -13.8125, 18.6250, 0.9062],\n", + " [-10.0000, -10.7500, -2.4688, ..., 5.5312, 2.4688, 8.2500]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-27.8750, 5.0938, 27.5000, ..., 28.3750, 27.7500, 12.1875],\n", + " [-16.7500, 5.4688, 26.0000, ..., 24.7500, 27.6250, 13.7500],\n", + " [ -9.7500, -11.6250, -8.8750, ..., -12.5000, 40.5000, -0.9688],\n", + " ...,\n", + " [-14.3125, -31.3750, -5.1562, ..., -22.0000, 6.3750, 33.0000],\n", + " [-29.5000, -45.7500, -17.2500, ..., -20.5000, 15.5625, 3.9688],\n", + " [ -7.8750, -15.5000, -4.1875, ..., 5.7500, -0.5039, 6.6875]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-65.5000, 11.6250, 38.2500, ..., 53.5000, 53.0000, 15.5000],\n", + " [-53.7500, 12.0000, 37.0000, ..., 49.5000, 53.0000, 17.0000],\n", + " [ -8.0625, -6.2500, 1.8750, ..., -4.9062, 59.7500, -8.0000],\n", + " ...,\n", + " [ -0.8750, -25.7500, 0.5000, ..., -22.6250, 0.0938, 20.6250],\n", + " [-30.2500, -42.5000, -17.3750, ..., -16.7500, 10.3750, -0.3125],\n", + " [ -9.6250, -18.2500, -3.0938, ..., 10.0625, -1.1562, 7.7812]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), tensor([[[-0.8320, -0.2168, 0.1982, ..., 0.5664, 0.9219, -0.0074],\n", + " [-0.8867, -0.3125, 0.2402, ..., 0.6094, 1.1094, 0.1016],\n", + " [-0.4258, 0.8047, 1.3828, ..., -0.0264, 2.2812, 0.5977],\n", + " ...,\n", + " [-0.6602, -1.9922, 0.0066, ..., -0.9844, -0.2910, 2.6719],\n", + " [-1.9062, -1.7734, -0.5859, ..., -0.9258, 0.6953, 0.6016],\n", + " [-1.3516, -0.9453, 0.1040, ..., 0.1582, 0.1592, 0.2598]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=)), attentions=None)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "o" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "6f0b9cc2", + "metadata": {}, + "outputs": [], + "source": [ + "hidden_states, attention_mask, position_ids = model.speech_generator.upsample(o.hidden_states[-1], tgt_units = tgt_units)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "6df14178", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[49152, 11062, 1483, 49152, 67, 277, 604, 24184, 287, 20462,\n", + " 278, 1287, 14852, 3017, 1148, 271, 3297, 1669, 317, 294,\n", + " 37430, 39802, 9573, 89, 1800, 1057, 390, 863, 801, 287,\n", + " 20462, 1079, 105, 494, 310, 390, 863, 33856, 278, 1111,\n", + " 494, 27427, 8662, 252, 518, 268, 332, 1673]],\n", + " device='cuda:0')" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_ids['input_ids']" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "4703de63", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(252.9116, device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.speech_generator(tgt_reps = o.hidden_states[-1], labels = input_ids['input_ids'], tgt_units = tgt_units)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ac9deb3", + "metadata": {}, + "outputs": [], + "source": [ + "ctc_loss" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3.10", + "language": "python", + "name": "python3.10" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}