Skip to content

Commit

Permalink
feat: update ml system UI (datahub-project#12334)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Sikowitz <[email protected]>
Co-authored-by: RyanHolstien <[email protected]>
Co-authored-by: Shirshanka Das <[email protected]>
Co-authored-by: ryota-cloud <[email protected]>
  • Loading branch information
5 people authored Jan 29, 2025
1 parent dbd57c9 commit 47134c2
Show file tree
Hide file tree
Showing 24 changed files with 989 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import com.linkedin.common.urn.Urn;
import com.linkedin.datahub.graphql.QueryContext;
import com.linkedin.datahub.graphql.generated.MLModelGroupProperties;
import com.linkedin.datahub.graphql.generated.MLModelLineageInfo;
import com.linkedin.datahub.graphql.types.common.mappers.CustomPropertiesMapper;
import com.linkedin.datahub.graphql.types.common.mappers.TimeStampToAuditStampMapper;
import com.linkedin.datahub.graphql.types.mappers.EmbeddedModelMapper;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -33,10 +36,40 @@ public MLModelGroupProperties apply(
result.setVersion(VersionTagMapper.map(context, mlModelGroupProperties.getVersion()));
}
result.setCreatedAt(mlModelGroupProperties.getCreatedAt());
if (mlModelGroupProperties.hasCreated()) {
result.setCreated(
TimeStampToAuditStampMapper.map(context, mlModelGroupProperties.getCreated()));
}
if (mlModelGroupProperties.getName() != null) {
result.setName(mlModelGroupProperties.getName());
} else {
// backfill name from URN for backwards compatibility
result.setName(entityUrn.getEntityKey().get(1)); // indexed access is safe here
}

if (mlModelGroupProperties.hasLastModified()) {
result.setLastModified(
TimeStampToAuditStampMapper.map(context, mlModelGroupProperties.getLastModified()));
}

result.setCustomProperties(
CustomPropertiesMapper.map(mlModelGroupProperties.getCustomProperties(), entityUrn));

final MLModelLineageInfo lineageInfo = new MLModelLineageInfo();
if (mlModelGroupProperties.hasTrainingJobs()) {
lineageInfo.setTrainingJobs(
mlModelGroupProperties.getTrainingJobs().stream()
.map(urn -> urn.toString())
.collect(Collectors.toList()));
}
if (mlModelGroupProperties.hasDownstreamJobs()) {
lineageInfo.setDownstreamJobs(
mlModelGroupProperties.getDownstreamJobs().stream()
.map(urn -> urn.toString())
.collect(Collectors.toList()));
}
result.setMlModelLineageInfo(lineageInfo);

return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.linkedin.common.urn.Urn;
import com.linkedin.datahub.graphql.QueryContext;
import com.linkedin.datahub.graphql.generated.MLModelGroup;
import com.linkedin.datahub.graphql.generated.MLModelLineageInfo;
import com.linkedin.datahub.graphql.generated.MLModelProperties;
import com.linkedin.datahub.graphql.types.common.mappers.CustomPropertiesMapper;
import com.linkedin.datahub.graphql.types.common.mappers.TimeStampToAuditStampMapper;
Expand Down Expand Up @@ -87,6 +88,20 @@ public MLModelProperties apply(
.collect(Collectors.toList()));
}
result.setTags(mlModelProperties.getTags());
final MLModelLineageInfo lineageInfo = new MLModelLineageInfo();
if (mlModelProperties.hasTrainingJobs()) {
lineageInfo.setTrainingJobs(
mlModelProperties.getTrainingJobs().stream()
.map(urn -> urn.toString())
.collect(Collectors.toList()));
}
if (mlModelProperties.hasDownstreamJobs()) {
lineageInfo.setDownstreamJobs(
mlModelProperties.getDownstreamJobs().stream()
.map(urn -> urn.toString())
.collect(Collectors.toList()));
}
result.setMlModelLineageInfo(lineageInfo);

return result;
}
Expand Down
29 changes: 29 additions & 0 deletions datahub-graphql-core/src/main/resources/lineage.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,32 @@ input LineageEdge {
"""
upstreamUrn: String!
}

"""
Represents lineage information for ML entities.
"""
type MLModelLineageInfo {
"""
List of jobs or processes used to train the model.
"""
trainingJobs: [String!]

"""
List of jobs or processes that use this model.
"""
downstreamJobs: [String!]
}

extend type MLModelProperties {
"""
Information related to lineage to this model group
"""
mlModelLineageInfo: MLModelLineageInfo
}

extend type MLModelGroupProperties {
"""
Information related to lineage to this model group
"""
mlModelLineageInfo: MLModelLineageInfo
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package com.linkedin.datahub.graphql.types.mlmodel.mappers;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;

import com.linkedin.common.urn.Urn;
import com.linkedin.ml.metadata.MLModelGroupProperties;
import java.net.URISyntaxException;
import org.testng.annotations.Test;

public class MLModelGroupPropertiesMapperTest {

@Test
public void testMapMLModelGroupProperties() throws URISyntaxException {
// Create backend ML Model Group Properties
MLModelGroupProperties input = new MLModelGroupProperties();

// Set description
input.setDescription("a ml trust model group");

// Set Name
input.setName("ML trust model group");

// Create URN
Urn groupUrn =
Urn.createFromString(
"urn:li:mlModelGroup:(urn:li:dataPlatform:sagemaker,another-group,PROD)");

// Map the properties
com.linkedin.datahub.graphql.generated.MLModelGroupProperties result =
MLModelGroupPropertiesMapper.map(null, input, groupUrn);

// Verify mapped properties
assertNotNull(result);
assertEquals(result.getDescription(), "a ml trust model group");
assertEquals(result.getName(), "ML trust model group");

// Verify lineage info is null as in the mock data
assertNotNull(result.getMlModelLineageInfo());
assertNull(result.getMlModelLineageInfo().getTrainingJobs());
assertNull(result.getMlModelLineageInfo().getDownstreamJobs());
}

@Test
public void testMapWithMinimalProperties() throws URISyntaxException {
// Create backend ML Model Group Properties with minimal information
MLModelGroupProperties input = new MLModelGroupProperties();

// Create URN
Urn groupUrn =
Urn.createFromString(
"urn:li:mlModelGroup:(urn:li:dataPlatform:sagemaker,another-group,PROD)");

// Map the properties
com.linkedin.datahub.graphql.generated.MLModelGroupProperties result =
MLModelGroupPropertiesMapper.map(null, input, groupUrn);

// Verify basic mapping with minimal properties
assertNotNull(result);
assertNull(result.getDescription());

// Verify lineage info is null
assertNotNull(result.getMlModelLineageInfo());
assertNull(result.getMlModelLineageInfo().getTrainingJobs());
assertNull(result.getMlModelLineageInfo().getDownstreamJobs());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package com.linkedin.datahub.graphql.types.mlmodel.mappers;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;

import com.linkedin.common.MLFeatureUrnArray;
import com.linkedin.common.TimeStamp;
import com.linkedin.common.VersionTag;
import com.linkedin.common.url.Url;
import com.linkedin.common.urn.MLFeatureUrn;
import com.linkedin.common.urn.MLModelUrn;
import com.linkedin.common.urn.Urn;
import com.linkedin.data.template.StringArray;
import com.linkedin.data.template.StringMap;
import com.linkedin.ml.metadata.MLHyperParam;
import com.linkedin.ml.metadata.MLHyperParamArray;
import com.linkedin.ml.metadata.MLMetric;
import com.linkedin.ml.metadata.MLMetricArray;
import com.linkedin.ml.metadata.MLModelProperties;
import java.net.URISyntaxException;
import org.testng.annotations.Test;

public class MLModelPropertiesMapperTest {

@Test
public void testMapMLModelProperties() throws URISyntaxException {
MLModelProperties input = new MLModelProperties();

// Set basic properties
input.setName("TestModel");
input.setDescription("A test ML model");
input.setType("Classification");

// Set version
VersionTag versionTag = new VersionTag();
versionTag.setVersionTag("1.0.0");
input.setVersion(versionTag);

// Set external URL
Url externalUrl = new Url("https://example.com/model");
input.setExternalUrl(externalUrl);

// Set created and last modified timestamps
TimeStamp createdTimeStamp = new TimeStamp();
createdTimeStamp.setTime(1000L);
Urn userUrn = Urn.createFromString("urn:li:corpuser:test");
createdTimeStamp.setActor(userUrn);
input.setCreated(createdTimeStamp);

TimeStamp lastModifiedTimeStamp = new TimeStamp();
lastModifiedTimeStamp.setTime(2000L);
lastModifiedTimeStamp.setActor(userUrn);
input.setLastModified(lastModifiedTimeStamp);

// Set custom properties
StringMap customProps = new StringMap();
customProps.put("key1", "value1");
customProps.put("key2", "value2");
input.setCustomProperties(customProps);

// Set hyper parameters
MLHyperParamArray hyperParams = new MLHyperParamArray();
MLHyperParam hyperParam1 = new MLHyperParam();
hyperParam1.setName("learning_rate");
hyperParam1.setValue("0.01");
hyperParams.add(hyperParam1);
input.setHyperParams(hyperParams);

// Set training metrics
MLMetricArray trainingMetrics = new MLMetricArray();
MLMetric metric1 = new MLMetric();
metric1.setName("accuracy");
metric1.setValue("0.95");
trainingMetrics.add(metric1);
input.setTrainingMetrics(trainingMetrics);

// Set ML features
MLFeatureUrnArray mlFeatures = new MLFeatureUrnArray();
MLFeatureUrn featureUrn = MLFeatureUrn.createFromString("urn:li:mlFeature:(dataset,feature)");
mlFeatures.add(featureUrn);
input.setMlFeatures(mlFeatures);

// Set tags
StringArray tags = new StringArray();
tags.add("tag1");
tags.add("tag2");
input.setTags(tags);

// Set training and downstream jobs
input.setTrainingJobs(
new com.linkedin.common.UrnArray(Urn.createFromString("urn:li:dataJob:train")));
input.setDownstreamJobs(
new com.linkedin.common.UrnArray(Urn.createFromString("urn:li:dataJob:predict")));

// Create ML Model URN
MLModelUrn modelUrn =
MLModelUrn.createFromString(
"urn:li:mlModel:(urn:li:dataPlatform:sagemaker,unittestmodel,PROD)");

// Map the properties
com.linkedin.datahub.graphql.generated.MLModelProperties result =
MLModelPropertiesMapper.map(null, input, modelUrn);

// Verify mapped properties
assertNotNull(result);
assertEquals(result.getName(), "TestModel");
assertEquals(result.getDescription(), "A test ML model");
assertEquals(result.getType(), "Classification");
assertEquals(result.getVersion(), "1.0.0");
assertEquals(result.getExternalUrl(), "https://example.com/model");

// Verify audit stamps
assertNotNull(result.getCreated());
assertEquals(result.getCreated().getTime().longValue(), 1000L);
assertEquals(result.getCreated().getActor(), userUrn.toString());

assertNotNull(result.getLastModified());
assertEquals(result.getLastModified().getTime().longValue(), 2000L);
assertEquals(result.getLastModified().getActor(), userUrn.toString());

// Verify custom properties
assertNotNull(result.getCustomProperties());

// Verify hyper parameters
assertNotNull(result.getHyperParams());
assertEquals(result.getHyperParams().size(), 1);
assertEquals(result.getHyperParams().get(0).getName(), "learning_rate");
assertEquals(result.getHyperParams().get(0).getValue(), "0.01");

// Verify training metrics
assertNotNull(result.getTrainingMetrics());
assertEquals(result.getTrainingMetrics().size(), 1);
assertEquals(result.getTrainingMetrics().get(0).getName(), "accuracy");
assertEquals(result.getTrainingMetrics().get(0).getValue(), "0.95");

// Verify ML features
assertNotNull(result.getMlFeatures());
assertEquals(result.getMlFeatures().size(), 1);
assertEquals(result.getMlFeatures().get(0), featureUrn.toString());

// Verify tags
assertNotNull(result.getTags());
assertEquals(result.getTags().get(0), "tag1");
assertEquals(result.getTags().get(1), "tag2");

// Verify lineage info
assertNotNull(result.getMlModelLineageInfo());
assertEquals(result.getMlModelLineageInfo().getTrainingJobs().size(), 1);
assertEquals(result.getMlModelLineageInfo().getTrainingJobs().get(0), "urn:li:dataJob:train");
assertEquals(result.getMlModelLineageInfo().getDownstreamJobs().size(), 1);
assertEquals(
result.getMlModelLineageInfo().getDownstreamJobs().get(0), "urn:li:dataJob:predict");
}

@Test
public void testMapWithMissingName() throws URISyntaxException {
MLModelProperties input = new MLModelProperties();
MLModelUrn modelUrn =
MLModelUrn.createFromString(
"urn:li:mlModel:(urn:li:dataPlatform:sagemaker,missingnamemodel,PROD)");

com.linkedin.datahub.graphql.generated.MLModelProperties result =
MLModelPropertiesMapper.map(null, input, modelUrn);

// Verify that name is extracted from URN when not present in input
assertEquals(result.getName(), "missingnamemodel");
}

@Test
public void testMapWithMinimalProperties() throws URISyntaxException {
MLModelProperties input = new MLModelProperties();
MLModelUrn modelUrn =
MLModelUrn.createFromString(
"urn:li:mlModel:(urn:li:dataPlatform:sagemaker,minimalmodel,PROD)");

com.linkedin.datahub.graphql.generated.MLModelProperties result =
MLModelPropertiesMapper.map(null, input, modelUrn);

// Verify basic mapping with minimal properties
assertNotNull(result);
assertEquals(result.getName(), "minimalmodel");
assertNull(result.getDescription());
assertNull(result.getType());
assertNull(result.getVersion());
}
}
1 change: 1 addition & 0 deletions datahub-web-react/src/app/entity/EntityPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export const EntityPage = ({ entityType }: Props) => {
entityType === EntityType.MlfeatureTable ||
entityType === EntityType.MlmodelGroup ||
entityType === EntityType.GlossaryTerm ||
entityType === EntityType.DataProcessInstance ||
entityType === EntityType.GlossaryNode;

return (
Expand Down
Loading

0 comments on commit 47134c2

Please sign in to comment.