Goal
My goal is to perform an expensive operation on a masked subset of elements, filling in zeros for the remaining elements. I'll start with an example that uses sum in place of the expensive operation:
# shape: (3,3,2)
inp = [ [ [ 1, 2 ], [ 2, 3 ], [ 3, 4 ] ],
[ [ 4, 5 ], [ 5, 6 ], [ 6, 7 ] ],
[ [ 7, 8 ], [ 8, 9 ], [ 9, 0 ] ] ]
# shape: (3,3)
mask = [ [ 0, 1, 0 ],
[ 1, 0, 0 ],
[ 0, 0, 0 ] ]
# expected sum output:
# shape: (3,3,1)
out = [ [ [ 0 ], [ 5 ], [ 0 ] ],
[ [ 9 ], [ 0 ], [ 0 ] ],
[ [ 0 ], [ 0 ], [ 0 ] ] ]
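For reference, a minimal NumPy sketch that reproduces the expected output from the literals above. Note it computes the sum everywhere and masks afterwards, which is exactly the waste to avoid once the op is genuinely expensive; it only pins down the target:
import numpy as np
inp_a = np.asarray( inp, dtype='float32' )    # (3, 3, 2)
mask_a = np.asarray( mask, dtype='float32' )  # (3, 3)
# Sum over the channel axis, then zero out the masked-off cells.
ref = mask_a[..., None] * inp_a.sum( axis=-1, keepdims=True )  # (3, 3, 1)
np.testing.assert_almost_equal( ref, np.asarray( out, dtype='float32' ) )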
Progress so far
I was able to get this working outside of a layer. A is my mask, E is my input data, and N is the number of elements along the row and column axes.
Partitioning has the side effect of flattening E, so I need to reshape the result back to the original row/column dimensions, but with the new channel dimension.
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow as tf
import numpy as np
def test1():
N = 3
testA = [[[0., 1., 0.],
[1., 0., 0.],
[0., 0., 0.]]]
testE = [[[[1., 2.],
[3.1, 4.1],
[5., 6.], ],
[[7.1, 8.1],
[9., 1.],
[9., 2.], ],
[[8., 3.],
[7., 4.],
[6., 5.], ], ]]
testA = np.asarray(testA).astype('float32')
testE = np.asarray(testE).astype('float32')
part1 = tf.dynamic_partition( testE, tf.cast( testA, 'int32' ), 2 )  # partitions must be int32
print( len( part1 ) )
print( part1[0] )
print( part1[1] )
"""
2
tf.Tensor(
[[1. 2.]
[5. 6.]
[9. 1.]
[9. 2.]
[8. 3.]
[7. 4.]
[6. 5.]], shape=(7, 2), dtype=float32)
tf.Tensor(
[[3.1 4.1]
[7.1 8.1]], shape=(2, 2), dtype=float32)
"""
sum1 = tf.math.reduce_sum( part1[1], axis=-1, keepdims=True )
print( sum1 )
"""
tf.Tensor(
[[ 7.2 ]
[15.200001]], shape=(2, 1), dtype=float32)
"""
indices1 = [
[ 0 ],
[ 2 ],
[ 4 ],
[ 5 ],
[ 6 ],
[ 7 ],
[ 8 ],
]
indices2 = [
[ 1 ],
[ 3 ],
]
indices = [ indices1, indices2 ]
partitioned_data = [
np.zeros( shape=(7,1), dtype='float32' ),  # must match sum1's dtype for the stitch
sum1
]
stitch1_flat = tf.dynamic_stitch( indices, partitioned_data )
print( stitch1_flat )
"""
tf.Tensor(
[ 0. 7.2 0. 15.200001 0. 0. 0.
0. 0. ], shape=(9,), dtype=float32)
"""
stitch1 = tf.reshape( stitch1_flat, (N,N,1) )
print( stitch1 )
"""
tf.Tensor(
[[[ 0. ]
[ 7.2 ]
[ 0. ]]
[[15.200001]
[ 0. ]
[ 0. ]]
[[ 0. ]
[ 0. ]
[ 0. ]]], shape=(3, 3, 1), dtype=float32)
"""
stitch1_np = stitch1.numpy()
target = np.array([[[ 0. ],
[ 7.2 ],
[ 0. ]],
[[15.200001],
[ 0. ],
[ 0. ]],
[[ 0. ],
[ 0. ],
[ 0. ]]])
np.testing.assert_almost_equal( stitch1_np, target, decimal=3 )
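The hand-written index lists above can also be derived from the mask itself. A small self-contained sketch of that idea (variable names are my own):
mask_flat = np.array( [ 0, 1, 0,
                        1, 0, 0,
                        0, 0, 0 ] )
ones_idx = np.where( mask_flat == 1 )[0][:, None]   # [[1], [3]] == indices2
zeros_idx = np.where( mask_flat == 0 )[0][:, None]  # [[0], [2], [4], [5], [6], [7], [8]] == indices1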
Where I need help
I'm having a hard time generalizing this into a Keras/TF layer. I can get the partitioning to work, but I'm struggling to compute the correct indices for the stitch. I also expect trouble calculating the size of the zero tensor that forms the other element of the stitch. Help with any of these sticking points would be greatly appreciated!
I'm also pretty bad at Python, so don't assume anything unconventional here is intentional; I probably just don't know a better way.
Thanks in advance!
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow as tf
import numpy as np
def test2():
class TestFlat(tf.keras.layers.Layer):
def __init__(self):
super(TestFlat, self).__init__()
self.N = -1 #size of row, column
self.S = -1 #size of input channel
def build(self, input_shape):
print( "input_shape: ", input_shape )
# TensorShape([None, 3, 3]), TensorShape([None, 3, 3, 2])]
assert( len( input_shape ) == 2 )
assert( len( input_shape[0] ) == 3 )
assert( len( input_shape[1] ) == 4 )
assert( input_shape[0][1] == input_shape[1][1] )
assert( input_shape[0][2] == input_shape[1][2] )
self.N = input_shape[0][1]
self.S = input_shape[1][3]
def call(self, inputs):
print( "inputs: ", inputs )
#[<tf.Tensor 'A_in:0' shape=(None, 3, 3) dtype=float32>,
# <tf.Tensor 'E_in:0' shape=(None, 3, 3, 2) dtype=float32>]
A = inputs[0] # mask
E = inputs[1] # data
A_int = tf.cast( A, "int32" )
part = tf.dynamic_partition( E, A_int, 2 )
print( len( part ) )
print( part[0] )
print( part[1] )
"""
2
tf.Tensor(
[[1. 2.]
[5. 6.]
[9. 1.]
[9. 2.]
[8. 3.]
[7. 4.]
[6. 5.]], shape=(7, 2), dtype=float32)
tf.Tensor(
[[3.1 4.1]
[7.1 8.1]], shape=(2, 2), dtype=float32)
"""
sum1 = tf.math.reduce_sum( part[1], axis=-1, keepdims=True )
# Okay so now we're done with the "expensive" calculation
# and we just need to merge with zeros back into our target shape of (None,N,N,1)
# Step 1: Calculate indices for stitching
# none of the rest of this works:
r = tf.range(self.N*self.N*self.S) #???
#tf.shape((None,self.N,self.N,1)) #???
s = tf.shape(E) #???
aa=tf.Variable(s) #???
aa[-1].assign( 1 ) #???
r = tf.reshape( r, s ) #???
indices = tf.dynamic_partition( r, A_int, 2 )
print( indices )
"""
partitioned_data = [
np.zeros( shape=(7,1) ),
sum1
]
"""
# Step 2: Create zero tensor
# Step 3: Stitch sum1 with zero tensor
return inputs[0] #dummy for now
N = 3
S = 2
A_in = Input(shape=(N, N), name='A_in')
E_in = Input(shape=(N, N, S), name='E_in')
out = TestFlat()( [A_in,E_in] )
model = Model(inputs=[A_in,E_in], outputs=out)
model.compile(optimizer='adam', loss='mean_squared_error')
model.summary()
testA = [[[0., 1., 0.],
[1., 0., 0.],
[0., 0., 0.]]]
testE = [[[[1., 2.],
[3.1, 4.1],
[5., 6.], ],
[[7.1, 8.1],
[9., 1.],
[9., 2.], ],
[[8., 3.],
[7., 4.],
[6., 5.], ], ]]
testA = np.asarray(testA).astype('float32')
testE = np.asarray(testE).astype('float32')
print( model([testA,testE]) )
I eventually found the answer, after learning that you don't actually need indices for the zero values. The zeros in the stitch are implicit, and you just need to add some onto the end of the last batch.
It's messy, but it works. Feel free to suggest how to do this better.
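Condensed to an eager-mode sketch with the toy numbers from above (my own boil-down, not the layer code below): dynamic_stitch only produces output up to the largest index it is given, the unreferenced slots below that come back as zeros (the behavior the trick relies on), and everything past the largest index has to be padded by hand.
import tensorflow as tf
idx = tf.constant( [ 1, 3 ] )                      # flat positions of the masked-in cells
vals = tf.constant( [ [ 7.2 ], [ 15.2 ] ] )        # their "expensive" results
stitched = tf.dynamic_stitch( [ idx ], [ vals ] )  # shape (4, 1); slots 0 and 2 are zero
pad = tf.zeros( [ 9 - tf.shape( stitched )[0], 1 ], dtype=stitched.dtype )
full = tf.reshape( tf.concat( [ stitched, pad ], axis=0 ), (3, 3, 1) )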
def test_flat_nbody_layer():
class TestFlat(tf.keras.layers.Layer):
def __init__(self):
super(TestFlat, self).__init__()
self.N = -1
self.S = -1
def build(self, input_shape):
print( "input_shape: ", input_shape )
# TensorShape([None, 3, 3]), TensorShape([None, 3, 3, 2])]
assert( len( input_shape ) == 2 )
assert( len( input_shape[0] ) == 3 )
assert( len( input_shape[1] ) == 4 )
assert( input_shape[0][1] == input_shape[1][1] )
assert( input_shape[0][2] == input_shape[1][2] )
self.N = input_shape[0][1]
self.S = input_shape[1][3]
pass
def call(self, inputs):
print( "inputs: ", inputs )
#[<tf.Tensor 'A_in:0' shape=(None, 3, 3) dtype=float32>,
# <tf.Tensor 'E_in:0' shape=(None, 3, 3, 2) dtype=float32>]
A = inputs[0]
E = inputs[1]
print( A ) #shape=(None, 3, 3)
print( E ) #shape=(None, 3, 3, 2)
A_int = tf.cast( A, "int32" )
part = tf.dynamic_partition( E, A_int, 2 )
print( len( part ) )
print( part[0] ) #shape=(None, 2)
print( part[1] ) #shape=(None, 2)
"""
2
tf.Tensor(
[[1. 2.]
[5. 6.]
[9. 1.]
[9. 2.]
[8. 3.]
[7. 4.]
[6. 5.]], shape=(7, 2), dtype=float32)
tf.Tensor(
[[3.1 4.1]
[7.1 8.1]], shape=(2, 2), dtype=float32)
"""
sum1 = tf.math.reduce_sum( part[1], axis=-1, keepdims=True )
print( sum1.shape )
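# Build flat positions 0..(batch*N*N - 1) and reshape them to the mask's
# shape; partitioning them the same way as E yields the stitch indices.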
x=tf.constant(self.N*self.N)
n=tf.constant(self.N)
r = tf.range(x*tf.shape(E)[0])
print( r ) #Tensor("test_flat/range:0", shape=(9,), dtype=int32)
print( "Batch Size:", tf.shape(E)[0] )
r2 = tf.reshape( r, shape=[tf.shape(E)[0],n,n] )
print( r2 ) #Tensor("test_flat/Reshape:0", shape=(1, 3, 3), dtype=int32)
condition_indices = tf.dynamic_partition( r2, A_int, 2 )
print( condition_indices )
#[<tf.Tensor 'test_flat/DynamicPartition_1:0' shape=(None,) dtype=int32>,
# <tf.Tensor 'test_flat/DynamicPartition_1:1' shape=(None,) dtype=int32>]
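# Only the masked-in partition needs indices and data; the zeros at the
# skipped positions below the largest index are implicit in the stitch.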
indices = [ condition_indices[ 1 ] ]
partitioned_data = [ sum1 ]
stitch_flat = tf.dynamic_stitch( indices, partitioned_data )
print( "stitch_flat", stitch_flat )
# Tensor("test_flat/DynamicStitch:0", shape=(None, 1), dtype=float32)
npad1 = tf.shape(E)[0] * n * n
print( "npad1", npad1 )
npad2 = tf.shape(stitch_flat)[0]
print( "npad2", npad2 )
nz = npad1 - npad2
print( "nz", nz )
zero_padding = tf.zeros(nz, dtype=stitch_flat.dtype)
print( "zeros", zero_padding )
zero_padding = tf.reshape( zero_padding, [nz,1] )
print( "zeros", zero_padding )
print( "tf.shape(stitch_flat)", tf.shape(stitch_flat) )
stitch = tf.concat([stitch_flat,zero_padding], -2 )
stitch = tf.reshape( stitch, [tf.shape(E)[0],n,n,1] )
return stitch
N = 3
S = 2
A_in = Input(shape=(N, N), name='A_in')
E_in = Input(shape=(N, N, S), name='E_in')
out = TestFlat()( [A_in,E_in] )
model = Model(inputs=[A_in,E_in], outputs=out)
model.compile(optimizer='adam', loss='mean_squared_error')
model.summary()
testA = [[[0., 1., 0.],
[1., 0., 0.],
[0., 0., 0.]]]
testE = [[[[1., 2.],
[3.1, 4.1],
[5., 6.], ],
[[7.1, 8.1],
[9., 1.],
[9., 2.], ],
[[8., 3.],
[7., 4.],
[6., 5.], ], ]]
target = np.array([[[ 0. ],
[ 7.2 ],
[ 0. ]],
[[15.200001],
[ 0. ],
[ 0. ]],
[[ 0. ],
[ 0. ],
[ 0. ]]])
assert_almost_equal = np.testing.assert_almost_equal
testA = np.asarray(testA).astype('float32')
testE = np.asarray(testE).astype('float32')
batch1_pred = model([testA,testE])
print( "test1", batch1_pred )
"""
tf.Tensor(
[[ 0. ]
[ 7.2 ]
[ 0. ]
[15.200001]], shape=(4, 1), dtype=float32)
"""
assert_almost_equal( batch1_pred[0], target, decimal=3 )
testA2 = np.asarray([ testA[0], testA[0] ])
testE2 = np.asarray([ testE[0], testE[0] ])
print( "testA2", testA2.shape )
print( "testE2", testE2.shape )
"""
testA2 (2, 3, 3)
testE2 (2, 3, 3, 2)
"""
batch2_pred = model([testA2,testE2])
print( "test2", batch2_pred )
for output in batch2_pred:
assert_almost_equal( output, target, decimal=3 )
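For what it's worth, a possible cleanup (a sketch under the same shapes, not tested against the layer above): tf.scatter_nd writes its updates into an explicitly zero-filled tensor, which would avoid both the implicit-zero behavior of dynamic_stitch and the padding at the end.
def masked_sum_scatter( A, E ):
    # A: (batch, N, N) 0/1 mask; E: (batch, N, N, S) data
    idx = tf.where( tf.cast( A, tf.bool ) )                  # (K, 3) coords of the ones
    vals = tf.reduce_sum( tf.gather_nd( E, idx ), axis=-1 )  # (K,) per-cell sums
    # scatter_nd leaves every position it does not write as zero
    out = tf.scatter_nd( idx, vals, tf.shape( A, out_type=tf.int64 ) )
    return out[..., None]                                    # (batch, N, N, 1)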