Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-46679][SQL] Fix for SparkUnsupportedOperationException Not found an encoder of the type T, when using Parameterized class #48304

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ object JavaTypeInference {
encoderFor(typeVariables(tv), seenTypeSet, typeVariables)

case pt: ParameterizedType =>
encoderFor(pt.getRawType, seenTypeSet, JavaTypeUtils.getTypeArguments(pt).asScala.toMap)
val newTvs = JavaTypeUtils.getTypeArguments(pt).asScala.toMap
val allTvs = newTvs ++ typeVariables.removedAll(newTvs.keySet)
encoderFor(pt.getRawType, seenTypeSet, allTvs)

case c: Class[_] =>
if (seenTypeSet.contains(c)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst;

import java.io.Serializable;

public class JavaTypeInferenceBeans {

static class JavaBeanWithGenericsA<T> {
Expand Down Expand Up @@ -78,5 +80,92 @@ static class JavaBeanWithGenericBase extends JavaBeanWithGenerics<String, String
static class JavaBeanWithGenericHierarchy extends JavaBeanWithGenericsABC<Integer> {

}
}

// Plain (non-Serializable) bean used as the concrete type argument for
// Company<T>/Team<P> in the SPARK-46679 generic-resolution tests.
static class PersonData {
private String id;

public String getId() {
return id;
}

public void setId(String id) {
this.id = id;
}
}

// Generic bean whose single property is the bare type variable P.
// Encoding it requires P to be resolved from an enclosing parameterized type.
static class Team<P> {
P person;

public P getPerson() {
return person;
}

public void setPerson(P person) {
this.person = person;
}
}

// Generic bean that nests another generic bean (Team<T>) using its own type
// variable T. Note the variable is named T here but P inside Team — the
// "different type names" case of SPARK-46679.
static class Company<T> {
String name;
Team<T> team;

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public Team<T> getTeam() {
return team;
}

public void setTeam(Team<T> team) {
this.team = team;
}
}

// Concrete subclass fixing Company's T to PersonData; the encoder must
// propagate this binding down to the nested Team<T>.person field.
static class CompanyWrapper extends Company<PersonData> {
}

// Serializable variant of PersonData, needed to satisfy the
// "T extends Serializable" bound on TeamT/CompanyT below.
static class PersonDataSerializable extends PersonData implements Serializable {
}

// Same shape as Team, but the type variable is named T and carries a
// Serializable upper bound — the "same type names" case of SPARK-46679,
// where TeamT's T and CompanyT's T are distinct variables sharing a name.
static class TeamT<T extends Serializable> {
T person;

public T getPerson() {
return person;
}

public void setPerson(T person) {
this.person = person;
}
}

// Outer generic bean whose type variable is also named T (colliding by name
// with TeamT's T); resolution must key on the declaring class, not the name.
static class CompanyT<T extends Serializable> {
String name;
TeamT<T> team;

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public TeamT<T> getTeam() {
return team;
}

public void setTeam(TeamT<T> team) {
this.team = team;
}
}

// Concrete subclass fixing CompanyT's T to PersonDataSerializable.
static class CompanyWrapperT extends CompanyT<PersonDataSerializable> {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.beans.{BeanProperty, BooleanBeanProperty}
import scala.reflect.{classTag, ClassTag}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.JavaTypeInferenceBeans.{JavaBeanWithGenericBase, JavaBeanWithGenericHierarchy, JavaBeanWithGenericsABC}
import org.apache.spark.sql.catalyst.JavaTypeInferenceBeans.{CompanyWrapper, CompanyWrapperT, JavaBeanWithGenericBase, JavaBeanWithGenericHierarchy, JavaBeanWithGenericsABC, PersonData, PersonDataSerializable, Team, TeamT}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, UDTCaseClass, UDTForCaseClass}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.types.{DecimalType, MapType, Metadata, StringType, StructField, StructType}
Expand Down Expand Up @@ -279,4 +279,32 @@ class JavaTypeInferenceSuite extends SparkFunSuite {
))
assert(encoder === expected)
}

test("SPARK-46679: resolve generics with multi-level inheritance different type names") {
// CompanyWrapper extends Company[PersonData], so the type variable T declared
// on Company must resolve to PersonData when encoding the nested Team[T]
// field (Team declares its own variable as P — different name, same binding).
// Before the fix this raised "Not found an encoder of the type T".
val encoder = JavaTypeInference.encoderFor(classOf[CompanyWrapper])
val expected =
JavaBeanEncoder(ClassTag(classOf[CompanyWrapper]), Seq(
encoderField("name", StringEncoder),
encoderField("team", JavaBeanEncoder(ClassTag(classOf[Team[PersonData]]), Seq(
encoderField("person", JavaBeanEncoder(ClassTag(classOf[PersonData]), Seq(
encoderField("id", StringEncoder)
)))
)))
))
assert(encoder === expected)
}

test("SPARK-46679: resolve generics with multi-level inheritance same type names") {
val encoder = JavaTypeInference.encoderFor(classOf[CompanyWrapperT])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent a bit of time playing around with the test bean classes, and I think the actual minimal example to reproduce the bug is

  static class Foo<T> {
    T t;

    public void setT(T t) {
      this.t = t;
    }

    public T getT() {
      return t;
    }
  }
  static class InnerWrapper<U> {
    Foo<U> foo;

    public void setFoo(Foo<U> foo) {
      this.foo = foo;
    }

    public Foo<U> getFoo() {
      return foo;
    }
  }
  static class OuterWrapper extends InnerWrapper<String> {}

while I know this test case came from the user's report, I think it would actually be easier for another developer to understand the point of this test case with the more minimal example.

I do think there is value in having the test cases with both combos of names for the type parameters still.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, sorry ignore this, I made a mistake posting my comments to github earlier. I meant to post this, but I think the simplification done already is fine

val expected =
JavaBeanEncoder(ClassTag(classOf[CompanyWrapperT]), Seq(
encoderField("name", StringEncoder),
encoderField("team", JavaBeanEncoder(ClassTag(classOf[TeamT[PersonDataSerializable]]), Seq(
encoderField("person", JavaBeanEncoder(ClassTag(classOf[PersonDataSerializable]), Seq(
encoderField("id", StringEncoder)
)))
)))
))
assert(encoder === expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,8 @@ public static class SimpleJavaBean implements Serializable {
private Map<Integer, String> g;
private Map<List<Long>, Map<String, String>> h;

private List<List<Long>> i;

public boolean isA() {
return a;
}
Expand Down Expand Up @@ -963,6 +965,14 @@ public void setH(Map<List<Long>, Map<String, String>> h) {
this.h = h;
}

// Accessors for the new nested-list property "i" (List<List<Long>>),
// added so testJavaBeanEncoder covers arrays-of-arrays in bean encoding.
public List<List<Long>> getI() {
return i;
}

public void setI(List<List<Long>> i) {
this.i = i;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand All @@ -977,7 +987,8 @@ public boolean equals(Object o) {
if (!e.equals(that.e)) return false;
if (!f.equals(that.f)) return false;
if (!g.equals(that.g)) return false;
return h.equals(that.h);
if (!h.equals(that.h)) return false;
return i.equals(that.i);

}

Expand All @@ -991,6 +1002,7 @@ public int hashCode() {
result = 31 * result + f.hashCode();
result = 31 * result + g.hashCode();
result = 31 * result + h.hashCode();
result = 31 * result + i.hashCode();
return result;
}
}
Expand Down Expand Up @@ -1080,6 +1092,10 @@ public void testJavaBeanEncoder() {
Map<List<Long>, Map<String, String>> complexMap1 = new HashMap<>();
complexMap1.put(Arrays.asList(1L, 2L), nestedMap1);
obj1.setH(complexMap1);
List<Long> nestedList1 = List.of(1L, 2L, 3L);
List<Long> nestedList2 = List.of(4L, 5L, 6L);
List<List<Long>> complexList1 = List.of(nestedList1, nestedList2);
obj1.setI(complexList1);

SimpleJavaBean obj2 = new SimpleJavaBean();
obj2.setA(false);
Expand All @@ -1098,6 +1114,10 @@ public void testJavaBeanEncoder() {
Map<List<Long>, Map<String, String>> complexMap2 = new HashMap<>();
complexMap2.put(Arrays.asList(3L, 4L), nestedMap2);
obj2.setH(complexMap2);
List<Long> nestedList3 = List.of(1L, 2L, 7L);
List<Long> nestedList4 = List.of(4L, 5L, 8L);
List<List<Long>> complexList2 = List.of(nestedList3, nestedList4);
obj2.setI(complexList2);

List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class));
Expand All @@ -1118,7 +1138,8 @@ public void testJavaBeanEncoder() {
Arrays.asList("a", "b"),
Arrays.asList(100L, null, 200L),
map1,
complexMap1});
complexMap1,
complexList1});
Row row2 = new GenericRow(new Object[]{
false,
30,
Expand All @@ -1127,7 +1148,9 @@ public void testJavaBeanEncoder() {
Arrays.asList("x", "y"),
Arrays.asList(300L, null, 400L),
map2,
complexMap2});
complexMap2,
complexList2});

StructType schema = new StructType()
.add("a", BooleanType, false)
.add("b", IntegerType, false)
Expand All @@ -1136,7 +1159,9 @@ public void testJavaBeanEncoder() {
.add("e", createArrayType(StringType))
.add("f", createArrayType(LongType))
.add("g", createMapType(IntegerType, StringType))
.add("h",createMapType(createArrayType(LongType), createMapType(StringType, StringType)));
.add("h",createMapType(createArrayType(LongType), createMapType(StringType, StringType)))
.add("i", createArrayType(createArrayType(LongType))) ;

Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema)
.as(Encoders.bean(SimpleJavaBean.class));
Assertions.assertEquals(data, ds3.collectAsList());
Expand Down